diff --git a/elasticsearch_dsl/__init__.py b/elasticsearch_dsl/__init__.py
index fd4433c2..23408afd 100644
--- a/elasticsearch_dsl/__init__.py
+++ b/elasticsearch_dsl/__init__.py
@@ -16,7 +16,7 @@
 # under the License.

 from . import connections
-from .aggs import A
+from .aggs import A, Agg
 from .analysis import analyzer, char_filter, normalizer, token_filter, tokenizer
 from .document import AsyncDocument, Document
 from .document_base import InnerDoc, M, MetaField, mapped_field
@@ -81,7 +81,8 @@ from .function import SF
 from .index import AsyncIndex, AsyncIndexTemplate, Index, IndexTemplate
 from .mapping import AsyncMapping, Mapping
-from .query import Q
+from .query import Q, Query
+from .response import AggResponse, Response, UpdateByQueryResponse
 from .search import (
     AsyncEmptySearch,
     AsyncMultiSearch,
@@ -99,6 +100,8 @@ __versionstr__ = ".".join(map(str, VERSION))

 __all__ = [
     "A",
+    "Agg",
+    "AggResponse",
     "AsyncDocument",
     "AsyncEmptySearch",
     "AsyncFacetedSearch",
@@ -158,11 +161,13 @@
     "Object",
     "Percolator",
     "Q",
+    "Query",
     "Range",
     "RangeFacet",
     "RangeField",
     "RankFeature",
     "RankFeatures",
+    "Response",
     "SF",
     "ScaledFloat",
     "Search",
@@ -174,6 +179,7 @@
     "TokenCount",
     "UnknownDslObject",
     "UpdateByQuery",
+    "UpdateByQueryResponse",
     "ValidationException",
     "analyzer",
     "char_filter",
diff --git a/elasticsearch_dsl/_async/index.py b/elasticsearch_dsl/_async/index.py
index b3bb1e64..765e7438 100644
--- a/elasticsearch_dsl/_async/index.py
+++ b/elasticsearch_dsl/_async/index.py
@@ -38,7 +38,7 @@ def __init__(
         name: str,
         template: str,
         index: Optional["AsyncIndex"] = None,
-        order: Optional[str] = None,
+        order: Optional[int] = None,
         **kwargs: Any,
     ):
         if index is None:
@@ -100,7 +100,7 @@ def as_template(
         self,
         template_name: str,
         pattern: Optional[str] = None,
-        order: Optional[str] = None,
+        order: Optional[int] = None,
     ) -> AsyncIndexTemplate:
         # TODO: should we allow pattern to be a top-level arg?
         # or maybe have an IndexPattern that allows for it and have
diff --git a/elasticsearch_dsl/_async/search.py b/elasticsearch_dsl/_async/search.py
index 4ab0297b..94fbe289 100644
--- a/elasticsearch_dsl/_async/search.py
+++ b/elasticsearch_dsl/_async/search.py
@@ -16,7 +16,16 @@
 # under the License.

 import contextlib
-from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    AsyncIterator,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    cast,
+)

 from elasticsearch.exceptions import ApiError
 from elasticsearch.helpers import async_scan
@@ -68,6 +77,7 @@ async def count(self) -> int:
             query=cast(Optional[Dict[str, Any]], d.get("query", None)),
             **self._params,
         )
+
         return cast(int, resp["count"])

     async def execute(self, ignore_cache: bool = False) -> Response[_R]:
@@ -175,6 +185,10 @@ class AsyncMultiSearch(MultiSearchBase[_R]):

     _using: AsyncUsingType

+    if TYPE_CHECKING:
+
+        def add(self, search: AsyncSearch[_R]) -> Self: ...  # type: ignore[override]
+
     async def execute(
         self, ignore_cache: bool = False, raise_on_error: bool = True
     ) -> List[Response[_R]]:
diff --git a/elasticsearch_dsl/_sync/index.py b/elasticsearch_dsl/_sync/index.py
index 9f9cf53e..59508d51 100644
--- a/elasticsearch_dsl/_sync/index.py
+++ b/elasticsearch_dsl/_sync/index.py
@@ -38,7 +38,7 @@ def __init__(
         name: str,
         template: str,
         index: Optional["Index"] = None,
-        order: Optional[str] = None,
+        order: Optional[int] = None,
         **kwargs: Any,
     ):
         if index is None:
@@ -94,7 +94,7 @@ def as_template(
         self,
         template_name: str,
         pattern: Optional[str] = None,
-        order: Optional[str] = None,
+        order: Optional[int] = None,
     ) -> IndexTemplate:
         # TODO: should we allow pattern to be a top-level arg?
         # or maybe have an IndexPattern that allows for it and have
diff --git a/elasticsearch_dsl/_sync/search.py b/elasticsearch_dsl/_sync/search.py
index 6ecbf27d..f3906f59 100644
--- a/elasticsearch_dsl/_sync/search.py
+++ b/elasticsearch_dsl/_sync/search.py
@@ -16,7 +16,7 @@
 # under the License.

 import contextlib
-from typing import Any, Dict, Iterator, List, Optional, cast
+from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, cast

 from elasticsearch.exceptions import ApiError
 from elasticsearch.helpers import scan
@@ -68,6 +68,7 @@ def count(self) -> int:
             query=cast(Optional[Dict[str, Any]], d.get("query", None)),
             **self._params,
         )
+
         return cast(int, resp["count"])

     def execute(self, ignore_cache: bool = False) -> Response[_R]:
@@ -169,6 +170,10 @@ class MultiSearch(MultiSearchBase[_R]):

     _using: UsingType

+    if TYPE_CHECKING:
+
+        def add(self, search: Search[_R]) -> Self: ...  # type: ignore[override]
+
     def execute(
         self, ignore_cache: bool = False, raise_on_error: bool = True
     ) -> List[Response[_R]]:
diff --git a/elasticsearch_dsl/aggs.py b/elasticsearch_dsl/aggs.py
index cf571947..ef032f7a 100644
--- a/elasticsearch_dsl/aggs.py
+++ b/elasticsearch_dsl/aggs.py
@@ -140,7 +140,12 @@ def __iter__(self) -> Iterable[str]:
         return iter(self.aggs)

     def _agg(
-        self, bucket: bool, name: str, agg_type: str, *args: Any, **params: Any
+        self,
+        bucket: bool,
+        name: str,
+        agg_type: Union[Dict[str, Any], Agg[_R], str],
+        *args: Any,
+        **params: Any,
     ) -> Agg[_R]:
         agg = self[name] = A(agg_type, *args, **params)

@@ -151,14 +156,32 @@ def _agg(
         else:
             return self._base

-    def metric(self, name: str, agg_type: str, *args: Any, **params: Any) -> Agg[_R]:
+    def metric(
+        self,
+        name: str,
+        agg_type: Union[Dict[str, Any], Agg[_R], str],
+        *args: Any,
+        **params: Any,
+    ) -> Agg[_R]:
         return self._agg(False, name, agg_type, *args, **params)

-    def bucket(self, name: str, agg_type: str, *args: Any, **params: Any) -> Agg[_R]:
-        return self._agg(True, name, agg_type, *args, **params)
-
-    def pipeline(self, name: str, agg_type: str, *args: Any, **params: Any) -> Agg[_R]:
-        return self._agg(False, name, agg_type, *args, **params)
+    def bucket(
+        self,
+        name: str,
+        agg_type: Union[Dict[str, Any], Agg[_R], str],
+        *args: Any,
+        **params: Any,
+    ) -> "Bucket[_R]":
+        return cast("Bucket[_R]", self._agg(True, name, agg_type, *args, **params))
+
+    def pipeline(
+        self,
+        name: str,
+        agg_type: Union[Dict[str, Any], Agg[_R], str],
+        *args: Any,
+        **params: Any,
+    ) -> "Pipeline[_R]":
+        return cast("Pipeline[_R]", self._agg(False, name, agg_type, *args, **params))

     def result(self, search: "SearchBase[_R]", data: Any) -> AttrDict[Any]:
         return BucketData(self, search, data)  # type: ignore
diff --git a/elasticsearch_dsl/document_base.py b/elasticsearch_dsl/document_base.py
index aa378738..23b10c0a 100644
--- a/elasticsearch_dsl/document_base.py
+++ b/elasticsearch_dsl/document_base.py
@@ -195,9 +195,7 @@ def __init__(self, name: str, bases: Tuple[type, ...], attrs: Dict[str, Any]):
                 field = None
                 field_args: List[Any] = []
                 field_kwargs: Dict[str, Any] = {}
-                if not isinstance(type_, type):
-                    raise TypeError(f"Cannot map type {type_}")
-                elif issubclass(type_, InnerDoc):
+                if isinstance(type_, type) and issubclass(type_, InnerDoc):
                     # object or nested field
                     field = Nested if multi else Object
                     field_args = [type_]
diff --git a/elasticsearch_dsl/faceted_search_base.py b/elasticsearch_dsl/faceted_search_base.py
index 2bd2e9b9..b959b05a 100644
--- a/elasticsearch_dsl/faceted_search_base.py
+++ b/elasticsearch_dsl/faceted_search_base.py
@@ -16,7 +16,19 @@
 # under the License.

 from datetime import datetime, timedelta
-from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple, Union, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Generic,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    Type,
+    Union,
+    cast,
+)

 from typing_extensions import Self
@@ -26,10 +38,11 @@
 from .utils import _R, AttrDict

 if TYPE_CHECKING:
+    from .document_base import DocumentBase
     from .response.aggs import BucketData
     from .search_base import SearchBase

-FilterValueType = Union[str, datetime]
+FilterValueType = Union[str, datetime, Sequence[str]]

 __all__ = [
     "FacetedSearchBase",
@@ -51,7 +64,7 @@ class Facet(Generic[_R]):
     agg_type: str = ""

     def __init__(
-        self, metric: Optional[str] = None, metric_sort: str = "desc", **kwargs: Any
+        self, metric: Optional[Agg[_R]] = None, metric_sort: str = "desc", **kwargs: Any
     ):
         self.filter_values = ()
         self._params = kwargs
@@ -137,7 +150,9 @@ def add_filter(self, filter_values: List[FilterValueType]) -> Optional[Query]:
 class RangeFacet(Facet[_R]):
     agg_type = "range"

-    def _range_to_dict(self, range: Tuple[Any, Tuple[int, int]]) -> Dict[str, Any]:
+    def _range_to_dict(
+        self, range: Tuple[Any, Tuple[Optional[int], Optional[int]]]
+    ) -> Dict[str, Any]:
         key, _range = range
         out: Dict[str, Any] = {"key": key}
         if _range[0] is not None:
@@ -146,7 +161,11 @@ def _range_to_dict(self, range: Tuple[Any, Tuple[int, int]]) -> Dict[str, Any]:
             out["to"] = _range[1]
         return out

-    def __init__(self, ranges: List[Tuple[Any, Tuple[int, int]]], **kwargs: Any):
+    def __init__(
+        self,
+        ranges: Sequence[Tuple[Any, Tuple[Optional[int], Optional[int]]]],
+        **kwargs: Any,
+    ):
         super().__init__(**kwargs)
         self._params["ranges"] = list(map(self._range_to_dict, ranges))
         self._params["keyed"] = False
@@ -277,7 +296,7 @@ class FacetedResponse(Response[_R]):
     _facets: Dict[str, List[Tuple[Any, int, bool]]]

     @property
-    def query_string(self) -> Optional[Query]:
+    def query_string(self) -> Optional[Union[str, Query]]:
         return self._faceted_search._query

     @property
@@ -334,9 +353,9 @@ def search(self):
         """

-    index = None
-    doc_types = None
-    fields: List[str] = []
+    index: Optional[str] = None
+    doc_types: Optional[List[Union[str, Type["DocumentBase"]]]] = None
+    fields: Sequence[str] = []
     facets: Dict[str, Facet[_R]] = {}
     using = "default"
@@ -346,9 +365,9 @@ def search(self) -> "SearchBase[_R]": ...
     def __init__(
         self,
-        query: Optional[Query] = None,
+        query: Optional[Union[str, Query]] = None,
         filters: Dict[str, FilterValueType] = {},
-        sort: List[str] = [],
+        sort: Sequence[str] = [],
     ):
         """
         :arg query: the text to search for
@@ -383,16 +402,18 @@ def add_filter(
         ]

         # remember the filter values for use in FacetedResponse
-        self.filter_values[name] = filter_values
+        self.filter_values[name] = filter_values  # type: ignore[assignment]

         # get the filter from the facet
-        f = self.facets[name].add_filter(filter_values)
+        f = self.facets[name].add_filter(filter_values)  # type: ignore[arg-type]
         if f is None:
             return

         self._filters[name] = f

-    def query(self, search: "SearchBase[_R]", query: Query) -> "SearchBase[_R]":
+    def query(
+        self, search: "SearchBase[_R]", query: Union[str, Query]
+    ) -> "SearchBase[_R]":
         """
         Add query part to ``search``.
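Editorial note, not part of the diff: the `aggs.py` change above makes `bucket()` return `Bucket[_R]` instead of a bare `Agg[_R]`, so nested aggregations can be chained without casts. A minimal usage sketch; the index and field names are invented:

```python
from elasticsearch_dsl import Search

s = Search(index="git")
# bucket() now returns the nested Bucket, so a metric can be attached to it
# directly and the whole chain type-checks under mypy --strict
s.aggs.bucket("authors", "terms", field="author.name.keyword").metric(
    "total_lines", "sum", field="stats.lines"
)
```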
diff --git a/elasticsearch_dsl/field.py b/elasticsearch_dsl/field.py
index b1c3afad..7896fe5f 100644
--- a/elasticsearch_dsl/field.py
+++ b/elasticsearch_dsl/field.py
@@ -225,9 +225,9 @@ def _empty(self) -> "InnerDoc":
     def _wrap(self, data: Dict[str, Any]) -> "InnerDoc":
         return self._doc_class.from_es(data, data_only=True)

-    def empty(self) -> Union["InnerDoc", AttrList]:
+    def empty(self) -> Union["InnerDoc", AttrList[Any]]:
         if self._multi:
-            return AttrList([], self._wrap)
+            return AttrList[Any]([], self._wrap)
         return self._empty()

     def to_dict(self) -> Dict[str, Any]:
diff --git a/elasticsearch_dsl/response/aggs.py b/elasticsearch_dsl/response/aggs.py
index 9cd61b06..3525e1f9 100644
--- a/elasticsearch_dsl/response/aggs.py
+++ b/elasticsearch_dsl/response/aggs.py
@@ -52,7 +52,7 @@ def __init__(

 class BucketData(AggResponse[_R]):
     _bucket_class = Bucket
-    _buckets: Union[AttrDict[Any], AttrList]
+    _buckets: Union[AttrDict[Any], AttrList[Any]]

     def _wrap_bucket(self, data: Dict[str, Any]) -> Bucket[_R]:
         return self._bucket_class(
@@ -70,11 +70,11 @@ def __len__(self) -> int:

     def __getitem__(self, key: Any) -> Any:
         if isinstance(key, (int, slice)):
-            return cast(AttrList, self.buckets)[key]
+            return cast(AttrList[Any], self.buckets)[key]
         return super().__getitem__(key)

     @property
-    def buckets(self) -> Union[AttrDict[Any], AttrList]:
+    def buckets(self) -> Union[AttrDict[Any], AttrList[Any]]:
         if not hasattr(self, "_buckets"):
             field = getattr(self._meta["aggs"], "field", None)
             if field:
diff --git a/elasticsearch_dsl/search_base.py b/elasticsearch_dsl/search_base.py
index e09e163d..b54bbaec 100644
--- a/elasticsearch_dsl/search_base.py
+++ b/elasticsearch_dsl/search_base.py
@@ -114,7 +114,7 @@ class ProxyDescriptor(Generic[_S]):
     def __init__(self, name: str):
         self._attr_name = f"_{name}_proxy"

-    def __get__(self, instance: _S, owner: object) -> QueryProxy[_S]:
+    def __get__(self, instance: Any, owner: object) -> QueryProxy[_S]:
         return cast(QueryProxy[_S], getattr(instance, self._attr_name))

     def __set__(self, instance: _S, value: Dict[str, Any]) -> None:
@@ -122,11 +122,11 @@ def __set__(self, instance: _S, value: Dict[str, Any]) -> None:
         proxy._proxied = Q(value)


-class AggsProxy(AggBase, DslBase, Generic[_S]):
+class AggsProxy(AggBase[_R], DslBase):
     name = "aggs"

-    def __init__(self, search: _S):
-        self._base = cast("Agg", self)
+    def __init__(self, search: "SearchBase[_R]"):
+        self._base = cast("Agg[_R]", self)
         self._search = search
         self._params = {"aggs": {}}

@@ -193,7 +193,7 @@ def params(self, **kwargs: Any) -> Self:
         s._params.update(kwargs)
         return s

-    def index(self, *index: str) -> Self:
+    def index(self, *index: Union[str, List[str], Tuple[str, ...]]) -> Self:
         """
         Set the index for the search. If called empty it will remove all information.
@@ -350,8 +350,8 @@ def to_dict(self) -> Dict[str, Any]: ...

 class SearchBase(Request[_R]):
-    query = ProxyDescriptor["SearchBase[_R]"]("query")
-    post_filter = ProxyDescriptor["SearchBase[_R]"]("post_filter")
+    query = ProxyDescriptor[Self]("query")
+    post_filter = ProxyDescriptor[Self]("post_filter")
     _response: Response[_R]

     def __init__(self, **kwargs: Any):
@@ -367,7 +367,7 @@ def __init__(self, **kwargs: Any):
         """
         super().__init__(**kwargs)

-        self.aggs = AggsProxy(self)
+        self.aggs = AggsProxy[_R](self)
         self._sort: List[Union[str, Dict[str, Dict[str, str]]]] = []
         self._knn: List[Dict[str, Any]] = []
         self._rank: Dict[str, Any] = {}
@@ -383,10 +383,10 @@ def __init__(self, **kwargs: Any):
         self._post_filter_proxy = QueryProxy(self, "post_filter")

     def filter(self, *args: Any, **kwargs: Any) -> Self:
-        return cast(Self, self.query(Bool(filter=[Q(*args, **kwargs)])))
+        return self.query(Bool(filter=[Q(*args, **kwargs)]))

     def exclude(self, *args: Any, **kwargs: Any) -> Self:
-        return cast(Self, self.query(Bool(filter=[~Q(*args, **kwargs)])))
+        return self.query(Bool(filter=[~Q(*args, **kwargs)]))

     def __getitem__(self, n: Union[int, slice]) -> Self:
         """
diff --git a/elasticsearch_dsl/update_by_query_base.py b/elasticsearch_dsl/update_by_query_base.py
index d02b4e51..e4490ddf 100644
--- a/elasticsearch_dsl/update_by_query_base.py
+++ b/elasticsearch_dsl/update_by_query_base.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.

-from typing import Any, Dict, Type, cast
+from typing import Any, Dict, Type

 from typing_extensions import Self
@@ -26,7 +26,7 @@

 class UpdateByQueryBase(Request[_R]):
-    query = ProxyDescriptor["UpdateByQueryBase[_R]"]("query")
+    query = ProxyDescriptor[Self]("query")

     def __init__(self, **kwargs: Any):
         """
@@ -46,10 +46,10 @@ def __init__(self, **kwargs: Any):
         self._query_proxy = QueryProxy(self, "query")

     def filter(self, *args: Any, **kwargs: Any) -> Self:
-        return cast(Self, self.query(Bool(filter=[Q(*args, **kwargs)])))
+        return self.query(Bool(filter=[Q(*args, **kwargs)]))

     def exclude(self, *args: Any, **kwargs: Any) -> Self:
-        return cast(Self, self.query(Bool(filter=[~Q(*args, **kwargs)])))
+        return self.query(Bool(filter=[~Q(*args, **kwargs)]))

     @classmethod
     def from_dict(cls, d: Dict[str, Any]) -> Self:
diff --git a/elasticsearch_dsl/utils.py b/elasticsearch_dsl/utils.py
index d3d70b5e..1f3eb6b6 100644
--- a/elasticsearch_dsl/utils.py
+++ b/elasticsearch_dsl/utils.py
@@ -86,9 +86,9 @@ def _wrap(val: Any, obj_wrapper: Optional[Callable[[Any], Any]] = None) -> Any:
     return val


-class AttrList:
+class AttrList(Generic[_ValT]):
     def __init__(
-        self, l: List[Any], obj_wrapper: Optional[Callable[[Any], Any]] = None
+        self, l: List[_ValT], obj_wrapper: Optional[Callable[[_ValT], Any]] = None
     ):
         # make iterables into lists
         if not isinstance(l, list):
@@ -111,10 +111,10 @@ def __ne__(self, other: Any) -> bool:
     def __getitem__(self, k: Union[int, slice]) -> Any:
         l = self._l_[k]
         if isinstance(k, slice):
-            return AttrList(l, obj_wrapper=self._obj_wrapper)
+            return AttrList[_ValT](l, obj_wrapper=self._obj_wrapper)  # type: ignore[arg-type]
         return _wrap(l, self._obj_wrapper)

-    def __setitem__(self, k: int, value: Any) -> None:
+    def __setitem__(self, k: int, value: _ValT) -> None:
         self._l_[k] = value

     def __iter__(self) -> Iterator[Any]:
@@ -131,15 +131,15 @@ def __nonzero__(self) -> bool:
     def __getattr__(self, name: str) -> Any:
         return getattr(self._l_, name)

-    def __getstate__(self) -> Tuple[List[Any], Optional[Callable[[Any], Any]]]:
+    def __getstate__(self) -> Tuple[List[_ValT], Optional[Callable[[_ValT], Any]]]:
         return self._l_, self._obj_wrapper

     def __setstate__(
-        self, state: Tuple[List[Any], Optional[Callable[[Any], Any]]]
+        self, state: Tuple[List[_ValT], Optional[Callable[[_ValT], Any]]]
     ) -> None:
         self._l_, self._obj_wrapper = state

-    def to_list(self) -> List[Any]:
+    def to_list(self) -> List[_ValT]:
         return self._l_


@@ -215,7 +215,6 @@ def __delitem__(self, key: str) -> None:
         del self._d_[key]

     def __setattr__(self, name: str, value: _ValT) -> None:
-        print(self._d_)
         if name in self._d_ or not hasattr(self.__class__, name):
             self._d_[name] = value
         else:
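A quick editorial sketch (not part of the diff) of what the now-generic `AttrList` buys: the element type flows through `to_list()` and `__setitem__`, while plain item access still goes through `_wrap` and returns `Any`. The values are invented:

```python
from typing import List

from elasticsearch_dsl.utils import AttrList

files: AttrList[str] = AttrList(["README", "README.rst"])
files[1] = "CHANGELOG.md"           # __setitem__ now only accepts str
plain: List[str] = files.to_list()  # typed as List[str], no cast needed
```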
diff --git a/examples/alias_migration.py b/examples/alias_migration.py
index e407990c..c9fe4ede 100644
--- a/examples/alias_migration.py
+++ b/examples/alias_migration.py
@@ -38,24 +38,30 @@
 import os
 from datetime import datetime
 from fnmatch import fnmatch
+from typing import TYPE_CHECKING, Any, Dict, List, Optional

-from elasticsearch_dsl import Date, Document, Keyword, Text, connections
+from elasticsearch_dsl import Document, Keyword, connections, mapped_field

 ALIAS = "test-blog"
 PATTERN = ALIAS + "-*"


 class BlogPost(Document):
-    title = Text()
-    published = Date()
-    tags = Keyword(multi=True)
-    content = Text()
+    if TYPE_CHECKING:
+        # definitions here help type checkers understand additional arguments
+        # that are allowed in the constructor
+        _id: int

-    def is_published(self):
-        return self.published and datetime.now() > self.published
+    title: str
+    tags: List[str] = mapped_field(Keyword())
+    content: str
+    published: Optional[datetime] = mapped_field(default=None)
+
+    def is_published(self) -> bool:
+        return bool(self.published and datetime.now() > self.published)

     @classmethod
-    def _matches(cls, hit):
+    def _matches(cls, hit: Dict[str, Any]) -> bool:
         # override _matches to match indices in a pattern instead of just ALIAS
         # hit is the raw dict as returned by elasticsearch
         return fnmatch(hit["_index"], PATTERN)
@@ -68,7 +74,7 @@ class Index:
         settings = {"number_of_shards": 1, "number_of_replicas": 0}


-def setup():
+def setup() -> None:
     """
     Create the index template in elasticsearch specifying the mappings and any
     settings to be used. This can be run at any time, ideally at every new code
@@ -85,7 +91,7 @@
     migrate(move_data=False)


-def migrate(move_data=True, update_alias=True):
+def migrate(move_data: bool = True, update_alias: bool = True) -> None:
     """
     Upgrade function that creates a new index for the data. Optionally it also
     can (and by default will) reindex previous copy of the data into the new index
@@ -125,7 +131,7 @@
     )


-def main():
+def main() -> None:
     # initiate the default connection to elasticsearch
     connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]])
diff --git a/examples/async/alias_migration.py b/examples/async/alias_migration.py
index 07bb995a..bede9098 100644
--- a/examples/async/alias_migration.py
+++ b/examples/async/alias_migration.py
@@ -39,24 +39,30 @@
 import os
 from datetime import datetime
 from fnmatch import fnmatch
+from typing import TYPE_CHECKING, Any, Dict, List, Optional

-from elasticsearch_dsl import AsyncDocument, Date, Keyword, Text, async_connections
+from elasticsearch_dsl import AsyncDocument, Keyword, async_connections, mapped_field

 ALIAS = "test-blog"
 PATTERN = ALIAS + "-*"


 class BlogPost(AsyncDocument):
-    title = Text()
-    published = Date()
-    tags = Keyword(multi=True)
-    content = Text()
+    if TYPE_CHECKING:
+        # definitions here help type checkers understand additional arguments
+        # that are allowed in the constructor
+        _id: int

-    def is_published(self):
-        return self.published and datetime.now() > self.published
+    title: str
+    tags: List[str] = mapped_field(Keyword())
+    content: str
+    published: Optional[datetime] = mapped_field(default=None)
+
+    def is_published(self) -> bool:
+        return bool(self.published and datetime.now() > self.published)

     @classmethod
-    def _matches(cls, hit):
+    def _matches(cls, hit: Dict[str, Any]) -> bool:
         # override _matches to match indices in a pattern instead of just ALIAS
         # hit is the raw dict as returned by elasticsearch
         return fnmatch(hit["_index"], PATTERN)
@@ -69,7 +75,7 @@ class Index:
         settings = {"number_of_shards": 1, "number_of_replicas": 0}


-async def setup():
+async def setup() -> None:
     """
     Create the index template in elasticsearch specifying the mappings and any
     settings to be used. This can be run at any time, ideally at every new code
@@ -86,7 +92,7 @@
     await migrate(move_data=False)


-async def migrate(move_data=True, update_alias=True):
+async def migrate(move_data: bool = True, update_alias: bool = True) -> None:
     """
     Upgrade function that creates a new index for the data. Optionally it also
     can (and by default will) reindex previous copy of the data into the new index
@@ -126,7 +132,7 @@
     )


-async def main():
+async def main() -> None:
     # initiate the default connection to elasticsearch
     async_connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]])
diff --git a/examples/async/completion.py b/examples/async/completion.py
index cbd6b3e1..a7a7a79e 100644
--- a/examples/async/completion.py
+++ b/examples/async/completion.py
@@ -29,6 +29,7 @@
 import asyncio
 import os
 from itertools import permutations
+from typing import TYPE_CHECKING, Any, Dict, Optional

 from elasticsearch_dsl import (
     AsyncDocument,
@@ -38,6 +39,7 @@
     Text,
     analyzer,
     async_connections,
+    mapped_field,
     token_filter,
 )
@@ -51,13 +53,18 @@

 class Person(AsyncDocument):
-    name = Text(fields={"keyword": Keyword()})
-    popularity = Long()
+    if TYPE_CHECKING:
+        # definitions here help type checkers understand additional arguments
+        # that are allowed in the constructor
+        _id: Optional[int] = mapped_field(default=None)
+
+    name: str = mapped_field(Text(fields={"keyword": Keyword()}), default="")
+    popularity: int = mapped_field(Long(), default=0)

     # completion field with a custom analyzer
-    suggest = Completion(analyzer=ascii_fold)
+    suggest: Dict[str, Any] = mapped_field(Completion(analyzer=ascii_fold), init=False)

-    def clean(self):
+    def clean(self) -> None:
         """
         Automatically construct the suggestion input and weight by taking all
         possible permutations of Person's name as ``input`` and taking their
@@ -73,7 +80,7 @@ class Index:
         settings = {"number_of_shards": 1, "number_of_replicas": 0}


-async def main():
+async def main() -> None:
     # initiate the default connection to elasticsearch
     async_connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]])
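An editorial sketch (not part of the diff) of the typed-document pattern the examples are migrating to: plain annotations become required constructor arguments, `mapped_field()` carries an explicit field type or a default, and the `TYPE_CHECKING` block teaches checkers about meta arguments such as `_id`. The `Article` class here is hypothetical:

```python
from datetime import datetime
from typing import TYPE_CHECKING, List, Optional

from elasticsearch_dsl import Document, Keyword, mapped_field


class Article(Document):
    if TYPE_CHECKING:
        # lets type checkers accept Article(_id=...) without affecting runtime
        _id: int

    title: str                                 # required, field inferred from type
    tags: List[str] = mapped_field(Keyword())  # explicit field type
    published: Optional[datetime] = mapped_field(default=None)


# the constructor signature is now visible to mypy and pyright
a = Article(_id=1, title="Hello", tags=["news"])
```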
""" - async def run_search(**kwargs): + async def run_search(**kwargs: Any) -> Response: s = search[:0] - s.aggs.bucket("comp", "composite", sources=source_aggs, size=size, **kwargs) + bucket = s.aggs.bucket( + "comp", "composite", sources=source_aggs, size=size, **kwargs + ) for agg_name, agg in inner_aggs.items(): - s.aggs["comp"][agg_name] = agg + bucket[agg_name] = agg return await s.execute() response = await run_search() @@ -46,7 +54,7 @@ async def run_search(**kwargs): response = await run_search(after=after) -async def main(): +async def main() -> None: # initiate the default connection to elasticsearch async_connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]]) diff --git a/examples/async/parent_child.py b/examples/async/parent_child.py index 822ce86e..2a74d2d4 100644 --- a/examples/async/parent_child.py +++ b/examples/async/parent_child.py @@ -42,19 +42,20 @@ import asyncio import os from datetime import datetime +from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast from elasticsearch_dsl import ( AsyncDocument, - Boolean, + AsyncIndex, + AsyncSearch, Date, InnerDoc, Join, Keyword, Long, - Nested, - Object, Text, async_connections, + mapped_field, ) @@ -63,11 +64,11 @@ class User(InnerDoc): Class used to represent a denormalized user stored on other objects. """ - id = Long(required=True) - signed_up = Date() - username = Text(fields={"keyword": Keyword()}, required=True) - email = Text(fields={"keyword": Keyword()}) - location = Text(fields={"keyword": Keyword()}) + id: int = mapped_field(Long()) + signed_up: Optional[datetime] = mapped_field(Date()) + username: str = mapped_field(Text(fields={"keyword": Keyword()})) + email: Optional[str] = mapped_field(Text(fields={"keyword": Keyword()})) + location: Optional[str] = mapped_field(Text(fields={"keyword": Keyword()})) class Comment(InnerDoc): @@ -75,9 +76,9 @@ class Comment(InnerDoc): Class wrapper for nested comment objects. """ - author = Object(User, required=True) - created = Date(required=True) - content = Text(required=True) + author: User + created: datetime + content: str class Post(AsyncDocument): @@ -85,14 +86,24 @@ class Post(AsyncDocument): Base class for Question and Answer containing the common fields. 
""" - author = Object(User, required=True) - created = Date(required=True) - body = Text(required=True) - comments = Nested(Comment) - question_answer = Join(relations={"question": "answer"}) + author: User + + if TYPE_CHECKING: + # definitions here help type checkers understand additional arguments + # that are allowed in the constructor + _routing: str = mapped_field(default=None) + _index: AsyncIndex = mapped_field(default=None) + _id: Optional[int] = mapped_field(default=None) + + created: Optional[datetime] = mapped_field(default=None) + body: str = mapped_field(default="") + comments: List[Comment] = mapped_field(default_factory=list) + question_answer: Any = mapped_field( + Join(relations={"question": "answer"}), default_factory=dict + ) @classmethod - def _matches(cls, hit): + def _matches(cls, hit: Dict[str, Any]) -> bool: # Post is an abstract class, make sure it never gets used for # deserialization return False @@ -104,35 +115,49 @@ class Index: "number_of_replicas": 0, } - async def add_comment(self, user, content, created=None, commit=True): + async def add_comment( + self, + user: User, + content: str, + created: Optional[datetime] = None, + commit: Optional[bool] = True, + ) -> Comment: c = Comment(author=user, content=content, created=created or datetime.now()) self.comments.append(c) if commit: await self.save() return c - async def save(self, **kwargs): + async def save(self, **kwargs: Any) -> None: # type: ignore[override] # if there is no date, use now if self.created is None: self.created = datetime.now() - return await super().save(**kwargs) + await super().save(**kwargs) class Question(Post): - # use multi True so that .tags will return empty list if not present - tags = Keyword(multi=True) - title = Text(fields={"keyword": Keyword()}) + tags: List[str] = mapped_field( + default_factory=list + ) # .tags will return empty list if not present + title: str = mapped_field(Text(fields={"keyword": Keyword()}), default="") @classmethod - def _matches(cls, hit): + def _matches(cls, hit: Dict[str, Any]) -> bool: """Use Question class for parent documents""" - return hit["_source"]["question_answer"] == "question" + return bool(hit["_source"]["question_answer"] == "question") @classmethod - def search(cls, **kwargs): + def search(cls, **kwargs: Any) -> AsyncSearch: # type: ignore[override] return cls._index.search(**kwargs).filter("term", question_answer="question") - async def add_answer(self, user, body, created=None, accepted=False, commit=True): + async def add_answer( + self, + user: User, + body: str, + created: Optional[datetime] = None, + accepted: bool = False, + commit: Optional[bool] = True, + ) -> "Answer": answer = Answer( # required make sure the answer is stored in the same shard _routing=self.meta.id, @@ -144,13 +169,13 @@ async def add_answer(self, user, body, created=None, accepted=False, commit=True author=user, created=created, body=body, - accepted=accepted, + is_accepted=accepted, ) if commit: await answer.save() return answer - def search_answers(self): + def search_answers(self) -> AsyncSearch: # search only our index s = Answer.search() # filter for answers belonging to us @@ -159,25 +184,25 @@ def search_answers(self): s = s.params(routing=self.meta.id) return s - async def get_answers(self): + async def get_answers(self) -> List[Any]: """ Get answers either from inner_hits already present or by searching elasticsearch. 
""" if "inner_hits" in self.meta and "answer" in self.meta.inner_hits: - return self.meta.inner_hits.answer.hits + return cast(List[Any], self.meta.inner_hits.answer.hits) return [a async for a in self.search_answers()] - async def save(self, **kwargs): + async def save(self, **kwargs: Any) -> None: # type: ignore[override] self.question_answer = "question" - return await super().save(**kwargs) + await super().save(**kwargs) class Answer(Post): - is_accepted = Boolean() + is_accepted: bool = mapped_field(default=False) @classmethod - def _matches(cls, hit): + def _matches(cls, hit: Dict[str, Any]) -> bool: """Use Answer class for child documents with child name 'answer'""" return ( isinstance(hit["_source"]["question_answer"], dict) @@ -185,31 +210,31 @@ def _matches(cls, hit): ) @classmethod - def search(cls, **kwargs): + def search(cls, **kwargs: Any) -> AsyncSearch: # type: ignore[override] return cls._index.search(**kwargs).exclude("term", question_answer="question") - async def get_question(self): + async def get_question(self) -> Optional[Question]: # cache question in self.meta # any attributes set on self would be interpreted as fields if "question" not in self.meta: self.meta.question = await Question.get( id=self.question_answer.parent, index=self.meta.index ) - return self.meta.question + return cast(Optional[Question], self.meta.question) - async def save(self, **kwargs): + async def save(self, **kwargs: Any) -> None: # type: ignore[override] # set routing to parents id automatically self.meta.routing = self.question_answer.parent - return await super().save(**kwargs) + await super().save(**kwargs) -async def setup(): +async def setup() -> None: """Create an IndexTemplate and save it into elasticsearch.""" index_template = Post._index.as_template("base") await index_template.save() -async def main(): +async def main() -> Answer: # initiate the default connection to elasticsearch async_connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]]) diff --git a/examples/async/percolate.py b/examples/async/percolate.py index 4b075cd8..efa83d4d 100644 --- a/examples/async/percolate.py +++ b/examples/async/percolate.py @@ -17,6 +17,7 @@ import asyncio import os +from typing import TYPE_CHECKING, Any, List, Optional from elasticsearch_dsl import ( AsyncDocument, @@ -24,8 +25,9 @@ Keyword, Percolator, Q, - Text, + Query, async_connections, + mapped_field, ) @@ -34,13 +36,18 @@ class BlogPost(AsyncDocument): Blog posts that will be automatically tagged based on percolation queries. """ - content = Text() - tags = Keyword(multi=True) + if TYPE_CHECKING: + # definitions here help type checkers understand additional arguments + # that are allowed in the constructor + _id: int + + content: Optional[str] + tags: List[str] = mapped_field(Keyword(), default_factory=list) class Index: name = "test-blogpost" - async def add_tags(self): + async def add_tags(self) -> None: # run a percolation to automatically tag the blog post. s = AsyncSearch(index="test-percolator") s = s.query( @@ -54,9 +61,9 @@ async def add_tags(self): # make sure tags are unique self.tags = list(set(self.tags)) - async def save(self, **kwargs): + async def save(self, **kwargs: Any) -> None: # type: ignore[override] await self.add_tags() - return await super().save(**kwargs) + await super().save(**kwargs) class PercolatorDoc(AsyncDocument): @@ -64,22 +71,25 @@ class PercolatorDoc(AsyncDocument): Document class used for storing the percolation queries. 
""" + if TYPE_CHECKING: + _id: str + # relevant fields from BlogPost must be also present here for the queries # to be able to use them. Another option would be to use document # inheritance but save() would have to be reset to normal behavior. - content = Text() + content: Optional[str] # the percolator query to be run against the doc - query = Percolator() + query: Query = mapped_field(Percolator()) # list of tags to append to a document - tags = Keyword(multi=True) + tags: List[str] = mapped_field(Keyword(multi=True)) class Index: name = "test-percolator" settings = {"number_of_shards": 1, "number_of_replicas": 0} -async def setup(): +async def setup() -> None: # create the percolator index if it doesn't exist if not await PercolatorDoc._index.exists(): await PercolatorDoc.init() @@ -88,11 +98,12 @@ async def setup(): await PercolatorDoc( _id="python", tags=["programming", "development", "python"], + content="", query=Q("match", content="python"), ).save(refresh=True) -async def main(): +async def main() -> None: # initiate the default connection to elasticsearch async_connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]]) diff --git a/examples/async/search_as_you_type.py b/examples/async/search_as_you_type.py index 3b76622e..67830b20 100644 --- a/examples/async/search_as_you_type.py +++ b/examples/async/search_as_you_type.py @@ -27,12 +27,14 @@ import asyncio import os +from typing import TYPE_CHECKING, Optional from elasticsearch_dsl import ( AsyncDocument, SearchAsYouType, analyzer, async_connections, + mapped_field, token_filter, ) from elasticsearch_dsl.query import MultiMatch @@ -47,14 +49,19 @@ class Person(AsyncDocument): - name = SearchAsYouType(max_shingle_size=3) + if TYPE_CHECKING: + # definitions here help type checkers understand additional arguments + # that are allowed in the constructor + _id: Optional[int] = mapped_field(default=None) + + name: str = mapped_field(SearchAsYouType(max_shingle_size=3), default="") class Index: name = "test-search-as-you-type" settings = {"number_of_shards": 1, "number_of_replicas": 0} -async def main(): +async def main() -> None: # initiate the default connection to elasticsearch async_connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]]) @@ -82,7 +89,7 @@ async def main(): for text in ("já", "Cimr", "toulouse", "Henri Tou", "a"): s = Person.search() - s.query = MultiMatch( + s.query = MultiMatch( # type: ignore[assignment] query=text, type="bool_prefix", fields=["name", "name._2gram", "name._3gram"], diff --git a/examples/async/sparse_vectors.py b/examples/async/sparse_vectors.py index 773cff80..d50b4080 100644 --- a/examples/async/sparse_vectors.py +++ b/examples/async/sparse_vectors.py @@ -63,21 +63,22 @@ import asyncio import json import os +from datetime import datetime +from typing import Any, Dict, List, Optional from urllib.request import urlopen -import nltk +import nltk # type: ignore from tqdm import tqdm from elasticsearch_dsl import ( AsyncDocument, - Date, + AsyncSearch, InnerDoc, Keyword, - Nested, Q, SparseVector, - Text, async_connections, + mapped_field, ) DATASET_URL = "https://raw.githubusercontent.com/elastic/elasticsearch-labs/main/datasets/workplace-documents.json" @@ -87,8 +88,8 @@ class Passage(InnerDoc): - content = Text() - embedding = SparseVector() + content: Optional[str] + embedding: Dict[str, float] = mapped_field(SparseVector(), init=False) class WorkplaceDoc(AsyncDocument): @@ -96,18 +97,18 @@ class Index: name = "workplace_documents_sparse" settings = {"default_pipeline": 
"elser_ingest_pipeline"} - name = Text() - summary = Text() - content = Text() - created = Date() - updated = Date() - url = Keyword() - category = Keyword() - passages = Nested(Passage) + name: str + summary: str + content: str + created: datetime + updated: Optional[datetime] + url: str = mapped_field(Keyword()) + category: str = mapped_field(Keyword()) + passages: List[Passage] = mapped_field(default=[]) - _model = None + _model: Any = None - def clean(self): + def clean(self) -> None: # split the content into sentences passages = nltk.sent_tokenize(self.content) @@ -116,7 +117,7 @@ def clean(self): self.passages.append(Passage(content=passage)) -async def create(): +async def create() -> None: # create the index await WorkplaceDoc._index.delete(ignore_unavailable=True) @@ -139,7 +140,7 @@ async def create(): await doc.save() -async def search(query): +async def search(query: str) -> AsyncSearch[WorkplaceDoc]: return WorkplaceDoc.search()[:5].query( "nested", path="passages", @@ -154,7 +155,7 @@ async def search(query): ) -def parse_args(): +def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Vector database with Elasticsearch") parser.add_argument( "--recreate-index", action="store_true", help="Recreate and populate the index" @@ -168,7 +169,7 @@ def parse_args(): return parser.parse_args() -async def main(): +async def main() -> None: args = parse_args() # initiate the default connection to elasticsearch diff --git a/examples/async/vectors.py b/examples/async/vectors.py index 2d84b516..5221929d 100644 --- a/examples/async/vectors.py +++ b/examples/async/vectors.py @@ -48,7 +48,7 @@ import json import os from datetime import datetime -from typing import List, Optional, cast +from typing import Any, List, Optional, cast from urllib.request import urlopen import nltk # type: ignore @@ -72,32 +72,34 @@ # initialize sentence tokenizer nltk.download("punkt", quiet=True) +# this will be the embedding model +embedding_model: Any = None + class Passage(InnerDoc): - content: M[str] - embedding: M[List[float]] = mapped_field(DenseVector()) + content: str + embedding: List[float] = mapped_field(DenseVector()) class WorkplaceDoc(AsyncDocument): class Index: name = "workplace_documents" - name: M[str] - summary: M[str] - content: M[str] - created: M[datetime] - updated: M[Optional[datetime]] - url: M[str] = mapped_field(Keyword(required=True)) - category: M[str] = mapped_field(Keyword(required=True)) + name: str + summary: str + content: str + created: datetime + updated: Optional[datetime] + url: str = mapped_field(Keyword(required=True)) + category: str = mapped_field(Keyword(required=True)) passages: M[List[Passage]] = mapped_field(default=[]) - _model = None - @classmethod def get_embedding(cls, input: str) -> List[float]: - if cls._model is None: - cls._model = SentenceTransformer(MODEL_NAME) - return cast(List[float], list(cls._model.encode(input))) + global embedding_model + if embedding_model is None: + embedding_model = SentenceTransformer(MODEL_NAME) + return cast(List[float], list(embedding_model.encode(input))) def clean(self) -> None: # split the content into sentences diff --git a/examples/completion.py b/examples/completion.py index 888eee62..81d1b0e4 100644 --- a/examples/completion.py +++ b/examples/completion.py @@ -28,6 +28,7 @@ import os from itertools import permutations +from typing import TYPE_CHECKING, Any, Dict, Optional from elasticsearch_dsl import ( Completion, @@ -37,6 +38,7 @@ Text, analyzer, connections, + mapped_field, token_filter, 
diff --git a/examples/completion.py b/examples/completion.py
index 888eee62..81d1b0e4 100644
--- a/examples/completion.py
+++ b/examples/completion.py
@@ -28,6 +28,7 @@
 import os
 from itertools import permutations
+from typing import TYPE_CHECKING, Any, Dict, Optional

 from elasticsearch_dsl import (
     Completion,
@@ -37,6 +38,7 @@
     Text,
     analyzer,
     connections,
+    mapped_field,
     token_filter,
 )
@@ -50,13 +52,18 @@

 class Person(Document):
-    name = Text(fields={"keyword": Keyword()})
-    popularity = Long()
+    if TYPE_CHECKING:
+        # definitions here help type checkers understand additional arguments
+        # that are allowed in the constructor
+        _id: Optional[int] = mapped_field(default=None)
+
+    name: str = mapped_field(Text(fields={"keyword": Keyword()}), default="")
+    popularity: int = mapped_field(Long(), default=0)

     # completion field with a custom analyzer
-    suggest = Completion(analyzer=ascii_fold)
+    suggest: Dict[str, Any] = mapped_field(Completion(analyzer=ascii_fold), init=False)

-    def clean(self):
+    def clean(self) -> None:
         """
         Automatically construct the suggestion input and weight by taking all
         possible permutations of Person's name as ``input`` and taking their
@@ -72,7 +79,7 @@ class Index:
         settings = {"number_of_shards": 1, "number_of_replicas": 0}


-def main():
+def main() -> None:
     # initiate the default connection to elasticsearch
     connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]])
diff --git a/examples/composite_agg.py b/examples/composite_agg.py
index 753a9cba..a0b992dd 100644
--- a/examples/composite_agg.py
+++ b/examples/composite_agg.py
@@ -16,22 +16,30 @@
 # under the License.

 import os
+from typing import Any, Dict, Iterator, List, Optional, Union

-from elasticsearch_dsl import A, Search, connections
+from elasticsearch_dsl import A, Agg, Response, Search, connections


-def scan_aggs(search, source_aggs, inner_aggs={}, size=10):
+def scan_aggs(
+    search: Search,
+    source_aggs: Union[Dict[str, Agg], List[Dict[str, Agg]]],
+    inner_aggs: Dict[str, Agg] = {},
+    size: Optional[int] = 10,
+) -> Iterator[Response]:
     """
     Helper function used to iterate over all possible bucket combinations of
     ``source_aggs``, returning results of ``inner_aggs`` for each. Uses the
     ``composite`` aggregation under the hood to perform this.
     """

-    def run_search(**kwargs):
+    def run_search(**kwargs: Any) -> Response:
         s = search[:0]
-        s.aggs.bucket("comp", "composite", sources=source_aggs, size=size, **kwargs)
+        bucket = s.aggs.bucket(
+            "comp", "composite", sources=source_aggs, size=size, **kwargs
+        )
         for agg_name, agg in inner_aggs.items():
-            s.aggs["comp"][agg_name] = agg
+            bucket[agg_name] = agg
         return s.execute()

     response = run_search()
@@ -45,7 +53,7 @@ def run_search(**kwargs):
         response = run_search(after=after)


-def main():
+def main() -> None:
     # initiate the default connection to elasticsearch
     connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]])
diff --git a/examples/parent_child.py b/examples/parent_child.py
index a4efdb3e..6d20dde2 100644
--- a/examples/parent_child.py
+++ b/examples/parent_child.py
@@ -41,19 +41,20 @@
 """
 import os
 from datetime import datetime
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast

 from elasticsearch_dsl import (
-    Boolean,
     Date,
     Document,
+    Index,
     InnerDoc,
     Join,
     Keyword,
     Long,
-    Nested,
-    Object,
+    Search,
     Text,
     connections,
+    mapped_field,
 )
@@ -62,11 +63,11 @@ class User(InnerDoc):
     Class used to represent a denormalized user stored on other objects.
     """

-    id = Long(required=True)
-    signed_up = Date()
-    username = Text(fields={"keyword": Keyword()}, required=True)
-    email = Text(fields={"keyword": Keyword()})
-    location = Text(fields={"keyword": Keyword()})
+    id: int = mapped_field(Long())
+    signed_up: Optional[datetime] = mapped_field(Date())
+    username: str = mapped_field(Text(fields={"keyword": Keyword()}))
+    email: Optional[str] = mapped_field(Text(fields={"keyword": Keyword()}))
+    location: Optional[str] = mapped_field(Text(fields={"keyword": Keyword()}))
@@ -74,9 +75,9 @@ class Comment(InnerDoc):
     """
     Class wrapper for nested comment objects.
     """

-    author = Object(User, required=True)
-    created = Date(required=True)
-    content = Text(required=True)
+    author: User
+    created: datetime
+    content: str
@@ -84,14 +85,24 @@ class Post(Document):
     """
     Base class for Question and Answer containing the common fields.
     """

-    author = Object(User, required=True)
-    created = Date(required=True)
-    body = Text(required=True)
-    comments = Nested(Comment)
-    question_answer = Join(relations={"question": "answer"})
+    author: User
+
+    if TYPE_CHECKING:
+        # definitions here help type checkers understand additional arguments
+        # that are allowed in the constructor
+        _routing: str = mapped_field(default=None)
+        _index: Index = mapped_field(default=None)
+        _id: Optional[int] = mapped_field(default=None)
+
+    created: Optional[datetime] = mapped_field(default=None)
+    body: str = mapped_field(default="")
+    comments: List[Comment] = mapped_field(default_factory=list)
+    question_answer: Any = mapped_field(
+        Join(relations={"question": "answer"}), default_factory=dict
+    )

     @classmethod
-    def _matches(cls, hit):
+    def _matches(cls, hit: Dict[str, Any]) -> bool:
         # Post is an abstract class, make sure it never gets used for
         # deserialization
         return False
@@ -103,35 +114,49 @@ class Index:
             "number_of_replicas": 0,
         }

-    def add_comment(self, user, content, created=None, commit=True):
+    def add_comment(
+        self,
+        user: User,
+        content: str,
+        created: Optional[datetime] = None,
+        commit: Optional[bool] = True,
+    ) -> Comment:
         c = Comment(author=user, content=content, created=created or datetime.now())
         self.comments.append(c)
         if commit:
             self.save()
         return c

-    def save(self, **kwargs):
+    def save(self, **kwargs: Any) -> None:  # type: ignore[override]
         # if there is no date, use now
         if self.created is None:
             self.created = datetime.now()
-        return super().save(**kwargs)
+        super().save(**kwargs)


 class Question(Post):
-    # use multi True so that .tags will return empty list if not present
-    tags = Keyword(multi=True)
-    title = Text(fields={"keyword": Keyword()})
+    tags: List[str] = mapped_field(
+        default_factory=list
+    )  # .tags will return empty list if not present
+    title: str = mapped_field(Text(fields={"keyword": Keyword()}), default="")

     @classmethod
-    def _matches(cls, hit):
+    def _matches(cls, hit: Dict[str, Any]) -> bool:
         """Use Question class for parent documents"""
-        return hit["_source"]["question_answer"] == "question"
+        return bool(hit["_source"]["question_answer"] == "question")

     @classmethod
-    def search(cls, **kwargs):
+    def search(cls, **kwargs: Any) -> Search:  # type: ignore[override]
         return cls._index.search(**kwargs).filter("term", question_answer="question")

-    def add_answer(self, user, body, created=None, accepted=False, commit=True):
+    def add_answer(
+        self,
+        user: User,
+        body: str,
+        created: Optional[datetime] = None,
+        accepted: bool = False,
+        commit: Optional[bool] = True,
+    ) -> "Answer":
         answer = Answer(
             # required make sure the answer is stored in the same shard
             _routing=self.meta.id,
@@ -143,13 +168,13 @@ def add_answer(self, user, body, created=None, accepted=False, commit=True):
             author=user,
             created=created,
             body=body,
-            accepted=accepted,
+            is_accepted=accepted,
         )
         if commit:
             answer.save()
         return answer

-    def search_answers(self):
+    def search_answers(self) -> Search:
         # search only our index
         s = Answer.search()
         # filter for answers belonging to us
@@ -158,25 +183,25 @@ def search_answers(self):
         s = s.params(routing=self.meta.id)
         return s

-    def get_answers(self):
+    def get_answers(self) -> List[Any]:
         """
         Get answers either from inner_hits already present or by searching
         elasticsearch.
         """
         if "inner_hits" in self.meta and "answer" in self.meta.inner_hits:
-            return self.meta.inner_hits.answer.hits
+            return cast(List[Any], self.meta.inner_hits.answer.hits)
         return [a for a in self.search_answers()]

-    def save(self, **kwargs):
+    def save(self, **kwargs: Any) -> None:  # type: ignore[override]
         self.question_answer = "question"
-        return super().save(**kwargs)
+        super().save(**kwargs)


 class Answer(Post):
-    is_accepted = Boolean()
+    is_accepted: bool = mapped_field(default=False)

     @classmethod
-    def _matches(cls, hit):
+    def _matches(cls, hit: Dict[str, Any]) -> bool:
         """Use Answer class for child documents with child name 'answer'"""
         return (
             isinstance(hit["_source"]["question_answer"], dict)
@@ -184,31 +209,31 @@ def _matches(cls, hit):
         )

     @classmethod
-    def search(cls, **kwargs):
+    def search(cls, **kwargs: Any) -> Search:  # type: ignore[override]
         return cls._index.search(**kwargs).exclude("term", question_answer="question")

-    def get_question(self):
+    def get_question(self) -> Optional[Question]:
         # cache question in self.meta
         # any attributes set on self would be interpreted as fields
         if "question" not in self.meta:
             self.meta.question = Question.get(
                 id=self.question_answer.parent, index=self.meta.index
             )
-        return self.meta.question
+        return cast(Optional[Question], self.meta.question)

-    def save(self, **kwargs):
+    def save(self, **kwargs: Any) -> None:  # type: ignore[override]
         # set routing to parents id automatically
         self.meta.routing = self.question_answer.parent
-        return super().save(**kwargs)
+        super().save(**kwargs)


-def setup():
+def setup() -> None:
     """Create an IndexTemplate and save it into elasticsearch."""
     index_template = Post._index.as_template("base")
     index_template.save()


-def main():
+def main() -> Answer:
     # initiate the default connection to elasticsearch
     connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]])
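A hedged usage sketch of the typed parent/child classes above (editorial, not from the diff; the user and text data are invented):

```python
from datetime import datetime

nick = User(
    id=47,
    signed_up=datetime.now(),
    username="fxdgear",
    email="nick.lang@elastic.co",
    location="Colorado",
)

question = Question(
    _id=1,
    author=nick,
    tags=["elasticsearch", "python"],
    title="How do I use elasticsearch from Python?",
    body="I want to use elasticsearch, how do I do it?",
)
question.save()
# add_answer() routes the child document to the parent's shard for us
answer = question.add_answer(nick, "Just use `elasticsearch-py`!", accepted=True)
```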
diff --git a/examples/percolate.py b/examples/percolate.py
index df470954..06d58114 100644
--- a/examples/percolate.py
+++ b/examples/percolate.py
@@ -16,15 +16,17 @@
 # under the License.

 import os
+from typing import TYPE_CHECKING, Any, List, Optional

 from elasticsearch_dsl import (
     Document,
     Keyword,
     Percolator,
     Q,
+    Query,
     Search,
-    Text,
     connections,
+    mapped_field,
 )
@@ -33,13 +35,18 @@ class BlogPost(Document):
     """
     Blog posts that will be automatically tagged based on percolation queries.
     """

-    content = Text()
-    tags = Keyword(multi=True)
+    if TYPE_CHECKING:
+        # definitions here help type checkers understand additional arguments
+        # that are allowed in the constructor
+        _id: int
+
+    content: Optional[str]
+    tags: List[str] = mapped_field(Keyword(), default_factory=list)

     class Index:
         name = "test-blogpost"

-    def add_tags(self):
+    def add_tags(self) -> None:
         # run a percolation to automatically tag the blog post.
         s = Search(index="test-percolator")
         s = s.query(
@@ -53,9 +60,9 @@ def add_tags(self):
         # make sure tags are unique
         self.tags = list(set(self.tags))

-    def save(self, **kwargs):
+    def save(self, **kwargs: Any) -> None:  # type: ignore[override]
         self.add_tags()
-        return super().save(**kwargs)
+        super().save(**kwargs)


 class PercolatorDoc(Document):
@@ -63,22 +70,25 @@ class PercolatorDoc(Document):
     Document class used for storing the percolation queries.
     """

+    if TYPE_CHECKING:
+        _id: str
+
     # relevant fields from BlogPost must be also present here for the queries
     # to be able to use them. Another option would be to use document
     # inheritance but save() would have to be reset to normal behavior.
-    content = Text()
+    content: Optional[str]

     # the percolator query to be run against the doc
-    query = Percolator()
+    query: Query = mapped_field(Percolator())

     # list of tags to append to a document
-    tags = Keyword(multi=True)
+    tags: List[str] = mapped_field(Keyword(multi=True))

     class Index:
         name = "test-percolator"
         settings = {"number_of_shards": 1, "number_of_replicas": 0}


-def setup():
+def setup() -> None:
     # create the percolator index if it doesn't exist
     if not PercolatorDoc._index.exists():
         PercolatorDoc.init()
@@ -87,11 +97,12 @@ def setup():
     PercolatorDoc(
         _id="python",
         tags=["programming", "development", "python"],
+        content="",
         query=Q("match", content="python"),
     ).save(refresh=True)


-def main():
+def main() -> None:
     # initiate the default connection to elasticsearch
     connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]])
diff --git a/examples/search_as_you_type.py b/examples/search_as_you_type.py
index 31ff3f02..e22b11da 100644
--- a/examples/search_as_you_type.py
+++ b/examples/search_as_you_type.py
@@ -26,12 +26,14 @@
 """
 import os
+from typing import TYPE_CHECKING, Optional

 from elasticsearch_dsl import (
     Document,
     SearchAsYouType,
     analyzer,
     connections,
+    mapped_field,
     token_filter,
 )
 from elasticsearch_dsl.query import MultiMatch
@@ -46,14 +48,19 @@

 class Person(Document):
-    name = SearchAsYouType(max_shingle_size=3)
+    if TYPE_CHECKING:
+        # definitions here help type checkers understand additional arguments
+        # that are allowed in the constructor
+        _id: Optional[int] = mapped_field(default=None)
+
+    name: str = mapped_field(SearchAsYouType(max_shingle_size=3), default="")

     class Index:
         name = "test-search-as-you-type"
         settings = {"number_of_shards": 1, "number_of_replicas": 0}


-def main():
+def main() -> None:
     # initiate the default connection to elasticsearch
     connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]])
@@ -81,7 +88,7 @@ def main():

     for text in ("já", "Cimr", "toulouse", "Henri Tou", "a"):
         s = Person.search()
-        s.query = MultiMatch(
+        s.query = MultiMatch(  # type: ignore[assignment]
             query=text,
             type="bool_prefix",
             fields=["name", "name._2gram", "name._3gram"],
diff --git a/examples/sparse_vectors.py b/examples/sparse_vectors.py
index 1de1a241..ae156fe7 100644
--- a/examples/sparse_vectors.py
+++ b/examples/sparse_vectors.py
@@ -62,21 +62,22 @@
 import argparse
 import json
 import os
+from datetime import datetime
+from typing import Any, Dict, List, Optional
 from urllib.request import urlopen

-import nltk
+import nltk  # type: ignore
 from tqdm import tqdm

 from elasticsearch_dsl import (
-    Date,
     Document,
     InnerDoc,
     Keyword,
-    Nested,
     Q,
+    Search,
     SparseVector,
-    Text,
     connections,
+    mapped_field,
 )

 DATASET_URL = "https://raw.githubusercontent.com/elastic/elasticsearch-labs/main/datasets/workplace-documents.json"
@@ -86,8 +87,8 @@

 class Passage(InnerDoc):
-    content = Text()
-    embedding = SparseVector()
+    content: Optional[str]
+    embedding: Dict[str, float] = mapped_field(SparseVector(), init=False)


 class WorkplaceDoc(Document):
@@ -95,18 +96,18 @@ class Index:
         name = "workplace_documents_sparse"
         settings = {"default_pipeline": "elser_ingest_pipeline"}

-    name = Text()
-    summary = Text()
-    content = Text()
-    created = Date()
-    updated = Date()
-    url = Keyword()
-    category = Keyword()
-    passages = Nested(Passage)
+    name: str
+    summary: str
+    content: str
+    created: datetime
+    updated: Optional[datetime]
+    url: str = mapped_field(Keyword())
+    category: str = mapped_field(Keyword())
+    passages: List[Passage] = mapped_field(default=[])

-    _model = None
+    _model: Any = None

-    def clean(self):
+    def clean(self) -> None:
         # split the content into sentences
         passages = nltk.sent_tokenize(self.content)

@@ -115,7 +116,7 @@ def clean(self):
             self.passages.append(Passage(content=passage))


-def create():
+def create() -> None:
     # create the index
     WorkplaceDoc._index.delete(ignore_unavailable=True)
@@ -138,7 +139,7 @@ def create():
     doc.save()


-def search(query):
+def search(query: str) -> Search[WorkplaceDoc]:
     return WorkplaceDoc.search()[:5].query(
         "nested",
         path="passages",
@@ -153,7 +154,7 @@ def search(query):
     )


-def parse_args():
+def parse_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser(description="Vector database with Elasticsearch")
     parser.add_argument(
         "--recreate-index", action="store_true", help="Recreate and populate the index"
@@ -167,7 +168,7 @@ def parse_args():
     return parser.parse_args()


-def main():
+def main() -> None:
     args = parse_args()

     # initiate the default connection to elasticsearch
diff --git a/examples/vectors.py b/examples/vectors.py
index 7b3eea8e..c983514d 100644
--- a/examples/vectors.py
+++ b/examples/vectors.py
@@ -47,7 +47,7 @@
 import json
 import os
 from datetime import datetime
-from typing import List, Optional, cast
+from typing import Any, List, Optional, cast
 from urllib.request import urlopen

 import nltk  # type: ignore
@@ -71,32 +71,34 @@
 # initialize sentence tokenizer
 nltk.download("punkt", quiet=True)

+# this will be the embedding model
+embedding_model: Any = None
+

 class Passage(InnerDoc):
-    content: M[str]
-    embedding: M[List[float]] = mapped_field(DenseVector())
+    content: str
+    embedding: List[float] = mapped_field(DenseVector())


 class WorkplaceDoc(Document):
     class Index:
         name = "workplace_documents"

-    name: M[str]
-    summary: M[str]
-    content: M[str]
-    created: M[datetime]
-    updated: M[Optional[datetime]]
-    url: M[str] = mapped_field(Keyword(required=True))
-    category: M[str] = mapped_field(Keyword(required=True))
+    name: str
+    summary: str
+    content: str
+    created: datetime
+    updated: Optional[datetime]
+    url: str = mapped_field(Keyword(required=True))
+    category: str = mapped_field(Keyword(required=True))
     passages: M[List[Passage]] = mapped_field(default=[])

-    _model = None
-
     @classmethod
     def get_embedding(cls, input: str) -> List[float]:
-        if cls._model is None:
-            cls._model = SentenceTransformer(MODEL_NAME)
-        return cast(List[float], list(cls._model.encode(input)))
+        global embedding_model
+        if embedding_model is None:
+            embedding_model = SentenceTransformer(MODEL_NAME)
+        return cast(List[float], list(embedding_model.encode(input)))

     def clean(self) -> None:
         # split the content into sentences
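A hedged sketch (editorial, not part of the diff) of querying the dense-vector documents defined above via the DSL's `knn()` search option; the `k`/`num_candidates` values and the query text are illustrative:

```python
# assumes WorkplaceDoc from examples/vectors.py is in scope and a default
# connection has been created
s = WorkplaceDoc.search().knn(
    field="passages.embedding",
    k=5,
    num_candidates=10,
    query_vector=WorkplaceDoc.get_embedding("How many vacation days do I get?"),
)
for hit in s.execute():
    print(hit.name)
```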
the License. -import subprocess - import nox SOURCE_FILES = ( @@ -29,19 +27,6 @@ "utils/", ) -TYPED_FILES = ( - # elasticsearch_dsl files are all assumed typed so they are omitted here - "tests/test_connections.py", - "tests/test_aggs.py", - "tests/test_analysis.py", - "tests/test_field.py", - "tests/test_query.py", - "tests/test_utils.py", - "tests/test_wrappers.py", - "examples/vectors.py", - "examples/async/vectors.py", -) - @nox.session( python=[ @@ -93,27 +78,20 @@ def lint(session): @nox.session(python="3.8") def type_check(session): - session.install("mypy", ".[develop]") - errors = [] - popen = subprocess.Popen( - "mypy --strict --implicit-reexport --explicit-package-bases elasticsearch_dsl tests examples", - env=session.env, - shell=True, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, + session.install(".[develop]") + session.run( + "mypy", + "--strict", + "--implicit-reexport", + "--explicit-package-bases", + "elasticsearch_dsl", + "tests", + "examples", + ) + session.run( + "pyright", + "examples", ) - - mypy_output = "" - while popen.poll() is None: - mypy_output += popen.stdout.read(8192).decode() - mypy_output += popen.stdout.read().decode() - - for line in mypy_output.split("\n"): - filepath = line.partition(":")[0] - if filepath.startswith("elasticsearch_dsl/") or filepath in TYPED_FILES: - errors.append(line) - if errors: - session.error("\n" + "\n".join(errors)) @nox.session() diff --git a/setup.py b/setup.py index e5d7b019..464dec56 100644 --- a/setup.py +++ b/setup.py @@ -46,16 +46,19 @@ "pytest-asyncio", "pytz", "coverage", - # typing support - "types-python-dateutil", # the following three are used by the vectors example and its tests "nltk", "sentence_transformers", "tqdm", - "types-tqdm", # Override Read the Docs default (sphinx<2 and sphinx-rtd-theme<0.5) "sphinx>2", "sphinx-rtd-theme>0.5", + # typing support + "mypy", + "pyright", + "types-python-dateutil", + "types-pytz", + "types-tqdm", ] setup( diff --git a/tests/_async/test_document.py b/tests/_async/test_document.py index 67d208bf..45a656b7 100644 --- a/tests/_async/test_document.py +++ b/tests/_async/test_document.py @@ -15,12 +15,18 @@ # specific language governing permissions and limitations # under the License. +# this file creates several documents using bad or no types because +# these are still supported and should be kept functional in spite +# of not having appropriate type hints. For that reason the comment +# below disables many mypy checks that fail as a result of this.
+# mypy: disable-error-code="assignment, index, arg-type, call-arg, operator, comparison-overlap, attr-defined" + import codecs import ipaddress import pickle from datetime import datetime from hashlib import md5 -from typing import List, Optional +from typing import Any, Dict, List, Optional import pytest from pytest import raises @@ -94,10 +100,10 @@ class Secret(str): class SecretField(field.CustomField): builtin_type = "text" - def _serialize(self, data): + def _serialize(self, data: Any) -> Any: return codecs.encode(data, "rot_13") - def _deserialize(self, data): + def _deserialize(self, data: Any) -> Any: if isinstance(data, Secret): return data return Secret(codecs.decode(data, "rot_13")) @@ -131,9 +137,9 @@ class Index: name = "test-host" -def test_range_serializes_properly(): +def test_range_serializes_properly() -> None: class D(AsyncDocument): - lr = field.LongRange() + lr: Range[int] = field.LongRange() d = D(lr=Range(lt=42)) assert 40 in d.lr @@ -144,7 +150,7 @@ class D(AsyncDocument): assert {"lr": {"lt": 42}} == d.to_dict() -def test_range_deserializes_properly(): +def test_range_deserializes_properly() -> None: class D(InnerDoc): lr = field.LongRange() @@ -154,13 +160,13 @@ class D(InnerDoc): assert 47 not in d.lr -def test_resolve_nested(): +def test_resolve_nested() -> None: nested, field = NestedSecret._index.resolve_nested("secrets.title") assert nested == ["secrets"] assert field is NestedSecret._doc_type.mapping["secrets"]["title"] -def test_conflicting_mapping_raises_error_in_index_to_dict(): +def test_conflicting_mapping_raises_error_in_index_to_dict() -> None: class A(AsyncDocument): name = field.Text() @@ -175,18 +181,18 @@ class B(AsyncDocument): i.to_dict() -def test_ip_address_serializes_properly(): +def test_ip_address_serializes_properly() -> None: host = Host(ip=ipaddress.IPv4Address("10.0.0.1")) assert {"ip": "10.0.0.1"} == host.to_dict() -def test_matches_uses_index(): +def test_matches_uses_index() -> None: assert SimpleCommit._matches({"_index": "test-git"}) assert not SimpleCommit._matches({"_index": "not-test-git"}) -def test_matches_with_no_name_always_matches(): +def test_matches_with_no_name_always_matches() -> None: class D(AsyncDocument): pass @@ -194,7 +200,7 @@ class D(AsyncDocument): assert D._matches({"_index": "whatever"}) -def test_matches_accepts_wildcards(): +def test_matches_accepts_wildcards() -> None: class MyDoc(AsyncDocument): class Index: name = "my-*" @@ -203,7 +209,7 @@ class Index: assert not MyDoc._matches({"_index": "not-my-index"}) -def test_assigning_attrlist_to_field(): +def test_assigning_attrlist_to_field() -> None: sc = SimpleCommit() l = ["README", "README.rst"] sc.files = utils.AttrList(l) @@ -211,13 +217,13 @@ def test_assigning_attrlist_to_field(): assert sc.to_dict()["files"] is l -def test_optional_inner_objects_are_not_validated_if_missing(): +def test_optional_inner_objects_are_not_validated_if_missing() -> None: d = OptionalObjectWithRequiredField() - assert d.full_clean() is None + d.full_clean() -def test_custom_field(): +def test_custom_field() -> None: s = SecretDoc(title=Secret("Hello")) assert {"title": "Uryyb"} == s.to_dict() @@ -228,13 +234,13 @@ def test_custom_field(): assert isinstance(s.title, Secret) -def test_custom_field_mapping(): +def test_custom_field_mapping() -> None: assert { "properties": {"title": {"index": "no", "type": "text"}} } == SecretDoc._doc_type.mapping.to_dict() -def test_custom_field_in_nested(): +def test_custom_field_in_nested() -> None: s = NestedSecret() 
s.secrets.append(SecretDoc(title=Secret("Hello"))) @@ -242,7 +248,7 @@ def test_custom_field_in_nested(): assert s.secrets[0].title == "Hello" -def test_multi_works_after_doc_has_been_saved(): +def test_multi_works_after_doc_has_been_saved() -> None: c = SimpleCommit() c.full_clean() c.files.append("setup.py") @@ -250,7 +256,7 @@ def test_multi_works_after_doc_has_been_saved(): assert c.to_dict() == {"files": ["setup.py"]} -def test_multi_works_in_nested_after_doc_has_been_serialized(): +def test_multi_works_in_nested_after_doc_has_been_serialized() -> None: # Issue #359 c = DocWithNested(comments=[Comment(title="First!")]) @@ -259,18 +265,18 @@ def test_multi_works_in_nested_after_doc_has_been_serialized(): assert [] == c.comments[0].tags -def test_null_value_for_object(): +def test_null_value_for_object() -> None: d = MyDoc(inner=None) assert d.inner is None -def test_inherited_doc_types_can_override_index(): +def test_inherited_doc_types_can_override_index() -> None: class MyDocDifferentIndex(MySubDoc): class Index: name = "not-default-index" settings = {"number_of_replicas": 0} - aliases = {"a": {}} + aliases: Dict[str, Any] = {"a": {}} analyzers = [analyzer("my_analizer", tokenizer="keyword")] assert MyDocDifferentIndex._index._name == "not-default-index" @@ -297,7 +303,7 @@ class Index: } -def test_to_dict_with_meta(): +def test_to_dict_with_meta() -> None: d = MySubDoc(title="hello") d.meta.routing = "some-parent" @@ -308,28 +314,28 @@ def test_to_dict_with_meta(): } == d.to_dict(True) -def test_to_dict_with_meta_includes_custom_index(): +def test_to_dict_with_meta_includes_custom_index() -> None: d = MySubDoc(title="hello") d.meta.index = "other-index" assert {"_index": "other-index", "_source": {"title": "hello"}} == d.to_dict(True) -def test_to_dict_without_skip_empty_will_include_empty_fields(): +def test_to_dict_without_skip_empty_will_include_empty_fields() -> None: d = MySubDoc(tags=[], title=None, inner={}) assert {} == d.to_dict() assert {"tags": [], "title": None, "inner": {}} == d.to_dict(skip_empty=False) -def test_attribute_can_be_removed(): +def test_attribute_can_be_removed() -> None: d = MyDoc(title="hello") del d.title assert "title" not in d._d_ -def test_doc_type_can_be_correctly_pickled(): +def test_doc_type_can_be_correctly_pickled() -> None: d = DocWithNested( title="Hello World!", comments=[Comment(title="hellp")], meta={"id": 42} ) @@ -344,7 +350,7 @@ def test_doc_type_can_be_correctly_pickled(): assert isinstance(d2.comments[0], Comment) -def test_meta_is_accessible_even_on_empty_doc(): +def test_meta_is_accessible_even_on_empty_doc() -> None: d = MyDoc() d.meta @@ -352,7 +358,7 @@ def test_meta_is_accessible_even_on_empty_doc(): d.meta -def test_meta_field_mapping(): +def test_meta_field_mapping() -> None: class User(AsyncDocument): username = field.Text() @@ -371,7 +377,7 @@ class Meta: } == User._doc_type.mapping.to_dict() -def test_multi_value_fields(): +def test_multi_value_fields() -> None: class Blog(AsyncDocument): tags = field.Keyword(multi=True) @@ -382,19 +388,19 @@ class Blog(AsyncDocument): assert ["search", "python"] == b.tags -def test_docs_with_properties(): +def test_docs_with_properties() -> None: class User(AsyncDocument): - pwd_hash = field.Text() + pwd_hash: str = field.Text() - def check_password(self, pwd): + def check_password(self, pwd: bytes) -> bool: return md5(pwd).hexdigest() == self.pwd_hash @property - def password(self): + def password(self) -> None: raise AttributeError("readonly") @password.setter - def password(self, pwd): 
+ def password(self, pwd: bytes) -> None: self.pwd_hash = md5(pwd).hexdigest() u = User(pwd_hash=md5(b"secret").hexdigest()) @@ -410,7 +416,7 @@ def password(self, pwd): u.password -def test_nested_can_be_assigned_to(): +def test_nested_can_be_assigned_to() -> None: d1 = DocWithNested(comments=[Comment(title="First!")]) d2 = DocWithNested() @@ -421,13 +427,13 @@ def test_nested_can_be_assigned_to(): assert isinstance(d2.comments[0], Comment) -def test_nested_can_be_none(): +def test_nested_can_be_none() -> None: d = DocWithNested(comments=None, title="Hello World!") assert {"title": "Hello World!"} == d.to_dict() -def test_nested_defaults_to_list_and_can_be_updated(): +def test_nested_defaults_to_list_and_can_be_updated() -> None: md = DocWithNested() assert [] == md.comments @@ -436,7 +442,7 @@ def test_nested_defaults_to_list_and_can_be_updated(): assert {"comments": [{"title": "hello World!"}]} == md.to_dict() -def test_to_dict_is_recursive_and_can_cope_with_multi_values(): +def test_to_dict_is_recursive_and_can_cope_with_multi_values() -> None: md = MyDoc(name=["a", "b", "c"]) md.inner = [MyInner(old_field="of1"), MyInner(old_field="of2")] @@ -448,13 +454,13 @@ def test_to_dict_is_recursive_and_can_cope_with_multi_values(): } == md.to_dict() -def test_to_dict_ignores_empty_collections(): +def test_to_dict_ignores_empty_collections() -> None: md = MySubDoc(name="", address={}, count=0, valid=False, tags=[]) assert {"name": "", "count": 0, "valid": False} == md.to_dict() -def test_declarative_mapping_definition(): +def test_declarative_mapping_definition() -> None: assert issubclass(MyDoc, AsyncDocument) assert hasattr(MyDoc, "_doc_type") assert { @@ -467,7 +473,7 @@ def test_declarative_mapping_definition(): } == MyDoc._doc_type.mapping.to_dict() -def test_you_can_supply_own_mapping_instance(): +def test_you_can_supply_own_mapping_instance() -> None: class MyD(AsyncDocument): title = field.Text() @@ -481,7 +487,7 @@ class Meta: } == MyD._doc_type.mapping.to_dict() -def test_document_can_be_created_dynamically(): +def test_document_can_be_created_dynamically() -> None: n = datetime.now() md = MyDoc(title="hello") md.name = "My Fancy Document!" 
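A side note on the `test_docs_with_properties` hunks above: the new annotations pin `pwd` to `bytes` because `hashlib.md5` only accepts bytes. A minimal plain-Python sketch of the same write-only property pattern (standalone, not part of the diff):

```python
# Sketch of the write-only password pattern from test_docs_with_properties:
# only the MD5 digest is ever stored; reading the property always raises.
from hashlib import md5


class User:
    def __init__(self, pwd_hash: str = "") -> None:
        self.pwd_hash = pwd_hash

    def check_password(self, pwd: bytes) -> bool:
        # hashlib requires bytes, hence the `bytes` annotations
        return md5(pwd).hexdigest() == self.pwd_hash

    @property
    def password(self) -> None:
        raise AttributeError("readonly")

    @password.setter
    def password(self, pwd: bytes) -> None:
        self.pwd_hash = md5(pwd).hexdigest()


u = User()
u.password = b"secret"  # stores only the digest
assert u.check_password(b"secret")
```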
@@ -502,14 +508,14 @@ def test_document_can_be_created_dynamically(): } == md.to_dict() -def test_invalid_date_will_raise_exception(): +def test_invalid_date_will_raise_exception() -> None: md = MyDoc() md.created_at = "not-a-date" with raises(ValidationException): md.full_clean() -def test_document_inheritance(): +def test_document_inheritance() -> None: assert issubclass(MySubDoc, MyDoc) assert issubclass(MySubDoc, AsyncDocument) assert hasattr(MySubDoc, "_doc_type") @@ -523,7 +529,7 @@ def test_document_inheritance(): } == MySubDoc._doc_type.mapping.to_dict() -def test_child_class_can_override_parent(): +def test_child_class_can_override_parent() -> None: class A(AsyncDocument): o = field.Object(dynamic=False, properties={"a": field.Text()}) @@ -541,7 +547,7 @@ class B(A): } == B._doc_type.mapping.to_dict() -def test_meta_fields_are_stored_in_meta_and_ignored_by_to_dict(): +def test_meta_fields_are_stored_in_meta_and_ignored_by_to_dict() -> None: md = MySubDoc(meta={"id": 42}, name="My First doc!") md.meta.index = "my-index" @@ -551,7 +557,7 @@ def test_meta_fields_are_stored_in_meta_and_ignored_by_to_dict(): assert {"id": 42, "index": "my-index"} == md.meta.to_dict() -def test_index_inheritance(): +def test_index_inheritance() -> None: assert issubclass(MyMultiSubDoc, MySubDoc) assert issubclass(MyMultiSubDoc, MyDoc2) assert issubclass(MyMultiSubDoc, AsyncDocument) @@ -568,7 +574,7 @@ def test_index_inheritance(): } == MyMultiSubDoc._doc_type.mapping.to_dict() -def test_meta_fields_can_be_set_directly_in_init(): +def test_meta_fields_can_be_set_directly_in_init() -> None: p = object() md = MyDoc(_id=p, title="Hello World!") @@ -576,27 +582,27 @@ def test_meta_fields_can_be_set_directly_in_init(): @pytest.mark.asyncio -async def test_save_no_index(async_mock_client): +async def test_save_no_index(async_mock_client: Any) -> None: md = MyDoc() with raises(ValidationException): await md.save(using="mock") @pytest.mark.asyncio -async def test_delete_no_index(async_mock_client): +async def test_delete_no_index(async_mock_client: Any) -> None: md = MyDoc() with raises(ValidationException): await md.delete(using="mock") @pytest.mark.asyncio -async def test_update_no_fields(): +async def test_update_no_fields() -> None: md = MyDoc() with raises(IllegalOperation): await md.update() -def test_search_with_custom_alias_and_index(): +def test_search_with_custom_alias_and_index() -> None: search_object = MyDoc.search( using="staging", index=["custom_index1", "custom_index2"] ) @@ -605,7 +611,7 @@ def test_search_with_custom_alias_and_index(): assert search_object._index == ["custom_index1", "custom_index2"] -def test_from_es_respects_underscored_non_meta_fields(): +def test_from_es_respects_underscored_non_meta_fields() -> None: doc = { "_index": "test-index", "_id": "elasticsearch", @@ -629,7 +635,7 @@ class Index: assert c._tagline == "You know, for search" -def test_nested_and_object_inner_doc(): +def test_nested_and_object_inner_doc() -> None: class MySubDocWithNested(MyDoc): nested_inner = field.Nested(MyInner) @@ -646,7 +652,7 @@ class MySubDocWithNested(MyDoc): } -def test_doc_with_type_hints(): +def test_doc_with_type_hints() -> None: class TypedInnerDoc(InnerDoc): st: M[str] dt: M[Optional[datetime]] @@ -656,16 +662,16 @@ class TypedDoc(AsyncDocument): st: str dt: Optional[datetime] li: List[int] - ob: Optional[TypedInnerDoc] - ns: Optional[List[TypedInnerDoc]] + ob: TypedInnerDoc + ns: List[TypedInnerDoc] ip: Optional[str] = field.Ip() k1: str = field.Keyword(required=True) k2: M[str] = 
field.Keyword() k3: str = mapped_field(field.Keyword(), default="foo") - k4: M[Optional[str]] = mapped_field(field.Keyword()) + k4: M[Optional[str]] = mapped_field(field.Keyword()) # type: ignore[misc] s1: Secret = SecretField() s2: M[Secret] = SecretField() - s3: Secret = mapped_field(SecretField()) + s3: Secret = mapped_field(SecretField()) # type: ignore[misc] s4: M[Optional[Secret]] = mapped_field( SecretField(), default_factory=lambda: "foo" ) @@ -707,27 +713,34 @@ class TypedDoc(AsyncDocument): assert doc.s4 == "foo" with raises(ValidationException) as exc_info: doc.full_clean() - assert set(exc_info.value.args[0].keys()) == {"st", "k1", "k2", "s1", "s2", "s3"} + assert set(exc_info.value.args[0].keys()) == { + "st", + "k1", + "k2", + "ob", + "s1", + "s2", + "s3", + } doc.st = "s" doc.li = [1, 2, 3] doc.k1 = "k1" doc.k2 = "k2" + doc.ob.st = "s" + doc.ob.li = [1] doc.s1 = "s1" doc.s2 = "s2" doc.s3 = "s3" doc.full_clean() - doc.ob = TypedInnerDoc() + doc.ob = TypedInnerDoc(li=[1]) with raises(ValidationException) as exc_info: doc.full_clean() assert set(exc_info.value.args[0].keys()) == {"ob"} assert set(exc_info.value.args[0]["ob"][0].args[0].keys()) == {"st"} doc.ob.st = "s" - doc.ob.li = [1] - doc.full_clean() - doc.ns.append(TypedInnerDoc(li=[1, 2])) with raises(ValidationException) as exc_info: doc.full_clean() @@ -766,7 +779,7 @@ class TypedDoc(AsyncDocument): assert s.to_dict() == {"sort": ["st", {"dt": {"order": "desc"}}, "ob.st"]} -def test_instrumented_field(): +def test_instrumented_field() -> None: class Child(InnerDoc): st: M[str] diff --git a/tests/_async/test_faceted_search.py b/tests/_async/test_faceted_search.py index 0a218887..c1d38af8 100644 --- a/tests/_async/test_faceted_search.py +++ b/tests/_async/test_faceted_search.py @@ -28,10 +28,10 @@ class BlogSearch(AsyncFacetedSearch): doc_types = ["user", "post"] - fields = ( + fields = [ "title^5", "body", - ) + ] facets = { "category": TermsFacet(field="category.raw"), @@ -39,7 +39,7 @@ class BlogSearch(AsyncFacetedSearch): } -def test_query_is_created_properly(): +def test_query_is_created_properly() -> None: bs = BlogSearch("python search") s = bs.build_search() @@ -56,13 +56,13 @@ def test_query_is_created_properly(): }, }, "query": { - "multi_match": {"fields": ("title^5", "body"), "query": "python search"} + "multi_match": {"fields": ["title^5", "body"], "query": "python search"} }, "highlight": {"fields": {"body": {}, "title": {}}}, } == s.to_dict() -def test_query_is_created_properly_with_sort_tuple(): +def test_query_is_created_properly_with_sort_tuple() -> None: bs = BlogSearch("python search", sort=("category", "-title")) s = bs.build_search() @@ -79,14 +79,14 @@ def test_query_is_created_properly_with_sort_tuple(): }, }, "query": { - "multi_match": {"fields": ("title^5", "body"), "query": "python search"} + "multi_match": {"fields": ["title^5", "body"], "query": "python search"} }, "highlight": {"fields": {"body": {}, "title": {}}}, "sort": ["category", {"title": {"order": "desc"}}], } == s.to_dict() -def test_filter_is_applied_to_search_but_not_relevant_facet(): +def test_filter_is_applied_to_search_but_not_relevant_facet() -> None: bs = BlogSearch("python search", filters={"category": "elastic"}) s = bs.build_search() @@ -103,13 +103,13 @@ def test_filter_is_applied_to_search_but_not_relevant_facet(): }, "post_filter": {"terms": {"category.raw": ["elastic"]}}, "query": { - "multi_match": {"fields": ("title^5", "body"), "query": "python search"} + "multi_match": {"fields": ["title^5", "body"], "query": 
"python search"} }, "highlight": {"fields": {"body": {}, "title": {}}}, } == s.to_dict() -def test_filters_are_applied_to_search_ant_relevant_facets(): +def test_filters_are_applied_to_search_ant_relevant_facets() -> None: bs = BlogSearch( "python search", filters={"category": "elastic", "tags": ["python", "django"]} ) @@ -135,17 +135,17 @@ def test_filters_are_applied_to_search_ant_relevant_facets(): }, }, "query": { - "multi_match": {"fields": ("title^5", "body"), "query": "python search"} + "multi_match": {"fields": ["title^5", "body"], "query": "python search"} }, "post_filter": {"bool": {}}, "highlight": {"fields": {"body": {}, "title": {}}}, } == d -def test_date_histogram_facet_with_1970_01_01_date(): +def test_date_histogram_facet_with_1970_01_01_date() -> None: dhf = DateHistogramFacet() - assert dhf.get_value({"key": None}) == datetime(1970, 1, 1, 0, 0) - assert dhf.get_value({"key": 0}) == datetime(1970, 1, 1, 0, 0) + assert dhf.get_value({"key": None}) == datetime(1970, 1, 1, 0, 0) # type: ignore[arg-type] + assert dhf.get_value({"key": 0}) == datetime(1970, 1, 1, 0, 0) # type: ignore[arg-type] @pytest.mark.parametrize( @@ -175,7 +175,7 @@ def test_date_histogram_facet_with_1970_01_01_date(): ("fixed_interval", "1h"), ], ) -def test_date_histogram_interval_types(interval_type, interval): +def test_date_histogram_interval_types(interval_type: str, interval: str) -> None: dhf = DateHistogramFacet(field="@timestamp", **{interval_type: interval}) assert dhf.get_aggregation().to_dict() == { "date_histogram": { @@ -187,14 +187,14 @@ def test_date_histogram_interval_types(interval_type, interval): dhf.get_value_filter(datetime.now()) -def test_date_histogram_no_interval_keyerror(): +def test_date_histogram_no_interval_keyerror() -> None: dhf = DateHistogramFacet(field="@timestamp") with pytest.raises(KeyError) as e: dhf.get_value_filter(datetime.now()) assert str(e.value) == "'interval'" -def test_params_added_to_search(): +def test_params_added_to_search() -> None: bs = BlogSearch("python search") assert bs._s._params == {} bs.params(routing="42") diff --git a/tests/_async/test_index.py b/tests/_async/test_index.py index a6e87776..c742a09b 100644 --- a/tests/_async/test_index.py +++ b/tests/_async/test_index.py @@ -17,6 +17,7 @@ import string from random import choice +from typing import Any, Dict import pytest from pytest import raises @@ -36,7 +37,7 @@ class Post(AsyncDocument): published_from = Date() -def test_multiple_doc_types_will_combine_mappings(): +def test_multiple_doc_types_will_combine_mappings() -> None: class User(AsyncDocument): username = Text() @@ -54,16 +55,16 @@ class User(AsyncDocument): } == i.to_dict() -def test_search_is_limited_to_index_name(): +def test_search_is_limited_to_index_name() -> None: i = AsyncIndex("my-index") s = i.search() assert s._index == ["my-index"] -def test_cloned_index_has_copied_settings_and_using(): +def test_cloned_index_has_copied_settings_and_using() -> None: client = object() - i = AsyncIndex("my-index", using=client) + i = AsyncIndex("my-index", using=client) # type: ignore[arg-type] i.settings(number_of_shards=1) i2 = i.clone("my-other-index") @@ -74,13 +75,13 @@ def test_cloned_index_has_copied_settings_and_using(): assert i._settings is not i2._settings -def test_cloned_index_has_analysis_attribute(): +def test_cloned_index_has_analysis_attribute() -> None: """ Regression test for Issue #582 in which `AsyncIndex.clone()` was not copying over the `_analysis` attribute. 
""" client = object() - i = AsyncIndex("my-index", using=client) + i = AsyncIndex("my-index", using=client) # type: ignore[arg-type] random_analyzer_name = "".join(choice(string.ascii_letters) for _ in range(100)) random_analyzer = analyzer( @@ -94,7 +95,7 @@ def test_cloned_index_has_analysis_attribute(): assert i.to_dict()["settings"]["analysis"] == i2.to_dict()["settings"]["analysis"] -def test_settings_are_saved(): +def test_settings_are_saved() -> None: i = AsyncIndex("i") i.settings(number_of_replicas=0) i.settings(number_of_shards=1) @@ -102,7 +103,7 @@ def test_settings_are_saved(): assert {"settings": {"number_of_shards": 1, "number_of_replicas": 0}} == i.to_dict() -def test_registered_doc_type_included_in_to_dict(): +def test_registered_doc_type_included_in_to_dict() -> None: i = AsyncIndex("i", using="alias") i.document(Post) @@ -116,7 +117,7 @@ def test_registered_doc_type_included_in_to_dict(): } == i.to_dict() -def test_registered_doc_type_included_in_search(): +def test_registered_doc_type_included_in_search() -> None: i = AsyncIndex("i", using="alias") i.document(Post) @@ -125,9 +126,9 @@ def test_registered_doc_type_included_in_search(): assert s._doc_type == [Post] -def test_aliases_add_to_object(): +def test_aliases_add_to_object() -> None: random_alias = "".join(choice(string.ascii_letters) for _ in range(100)) - alias_dict = {random_alias: {}} + alias_dict: Dict[str, Any] = {random_alias: {}} index = AsyncIndex("i", using="alias") index.aliases(**alias_dict) @@ -135,9 +136,9 @@ def test_aliases_add_to_object(): assert index._aliases == alias_dict -def test_aliases_returned_from_to_dict(): +def test_aliases_returned_from_to_dict() -> None: random_alias = "".join(choice(string.ascii_letters) for _ in range(100)) - alias_dict = {random_alias: {}} + alias_dict: Dict[str, Any] = {random_alias: {}} index = AsyncIndex("i", using="alias") index.aliases(**alias_dict) @@ -145,7 +146,7 @@ def test_aliases_returned_from_to_dict(): assert index._aliases == index.to_dict()["aliases"] == alias_dict -def test_analyzers_added_to_object(): +def test_analyzers_added_to_object() -> None: random_analyzer_name = "".join(choice(string.ascii_letters) for _ in range(100)) random_analyzer = analyzer( random_analyzer_name, tokenizer="standard", filter="standard" @@ -161,7 +162,7 @@ def test_analyzers_added_to_object(): } -def test_analyzers_returned_from_to_dict(): +def test_analyzers_returned_from_to_dict() -> None: random_analyzer_name = "".join(choice(string.ascii_letters) for _ in range(100)) random_analyzer = analyzer( random_analyzer_name, tokenizer="standard", filter="standard" @@ -174,7 +175,7 @@ def test_analyzers_returned_from_to_dict(): ] == {"filter": ["standard"], "type": "custom", "tokenizer": "standard"} -def test_conflicting_analyzer_raises_error(): +def test_conflicting_analyzer_raises_error() -> None: i = AsyncIndex("i") i.analyzer("my_analyzer", tokenizer="whitespace", filter=["lowercase", "stop"]) @@ -182,7 +183,7 @@ def test_conflicting_analyzer_raises_error(): i.analyzer("my_analyzer", tokenizer="keyword", filter=["lowercase", "stop"]) -def test_index_template_can_have_order(): +def test_index_template_can_have_order() -> None: i = AsyncIndex("i-*") it = i.as_template("i", order=2) @@ -190,7 +191,7 @@ def test_index_template_can_have_order(): @pytest.mark.asyncio -async def test_index_template_save_result(async_mock_client): +async def test_index_template_save_result(async_mock_client: Any) -> None: it = AsyncIndexTemplate("test-template", "test-*") assert await 
it.save(using="mock") == await async_mock_client.indices.put_template() diff --git a/tests/_async/test_mapping.py b/tests/_async/test_mapping.py index 6d47901c..f3b8538d 100644 --- a/tests/_async/test_mapping.py +++ b/tests/_async/test_mapping.py @@ -20,7 +20,7 @@ from elasticsearch_dsl import AsyncMapping, Keyword, Nested, Text, analysis -def test_mapping_can_has_fields(): +def test_mapping_can_has_fields() -> None: m = AsyncMapping() m.field("name", "text").field("tags", "keyword") @@ -29,7 +29,7 @@ def test_mapping_can_has_fields(): } == m.to_dict() -def test_mapping_update_is_recursive(): +def test_mapping_update_is_recursive() -> None: m1 = AsyncMapping() m1.field("title", "text") m1.field("author", "object") @@ -62,7 +62,7 @@ def test_mapping_update_is_recursive(): } == m1.to_dict() -def test_properties_can_iterate_over_all_the_fields(): +def test_properties_can_iterate_over_all_the_fields() -> None: m = AsyncMapping() m.field("f1", "text", test_attr="f1", fields={"f2": Keyword(test_attr="f2")}) m.field("f3", Nested(test_attr="f3", properties={"f4": Text(test_attr="f4")})) @@ -72,7 +72,7 @@ def test_properties_can_iterate_over_all_the_fields(): } -def test_mapping_can_collect_all_analyzers_and_normalizers(): +def test_mapping_can_collect_all_analyzers_and_normalizers() -> None: a1 = analysis.analyzer( "my_analyzer1", tokenizer="keyword", @@ -145,7 +145,7 @@ def test_mapping_can_collect_all_analyzers_and_normalizers(): assert json.loads(json.dumps(m.to_dict())) == m.to_dict() -def test_mapping_can_collect_multiple_analyzers(): +def test_mapping_can_collect_multiple_analyzers() -> None: a1 = analysis.analyzer( "my_analyzer1", tokenizer="keyword", @@ -191,7 +191,7 @@ def test_mapping_can_collect_multiple_analyzers(): } == m._collect_analysis() -def test_even_non_custom_analyzers_can_have_params(): +def test_even_non_custom_analyzers_can_have_params() -> None: a1 = analysis.analyzer("whitespace", type="pattern", pattern=r"\\s+") m = AsyncMapping() m.field("title", "text", analyzer=a1) @@ -201,14 +201,14 @@ def test_even_non_custom_analyzers_can_have_params(): } == m._collect_analysis() -def test_resolve_field_can_resolve_multifields(): +def test_resolve_field_can_resolve_multifields() -> None: m = AsyncMapping() m.field("title", "text", fields={"keyword": Keyword()}) assert isinstance(m.resolve_field("title.keyword"), Keyword) -def test_resolve_nested(): +def test_resolve_nested() -> None: m = AsyncMapping() m.field("n1", "nested", properties={"n2": Nested(properties={"k1": Keyword()})}) m.field("k2", "keyword") diff --git a/tests/_async/test_search.py b/tests/_async/test_search.py index fb60985c..603c3826 100644 --- a/tests/_async/test_search.py +++ b/tests/_async/test_search.py @@ -16,61 +16,62 @@ # under the License. 
from copy import deepcopy +from typing import Any import pytest from pytest import raises -from elasticsearch_dsl import A, AsyncEmptySearch, AsyncSearch, Document, Q, query +from elasticsearch_dsl import AsyncEmptySearch, AsyncSearch, Document, Q, query from elasticsearch_dsl.exceptions import IllegalOperation -def test_expand__to_dot_is_respected(): +def test_expand__to_dot_is_respected() -> None: s = AsyncSearch().query("match", a__b=42, _expand__to_dot=False) assert {"query": {"match": {"a__b": 42}}} == s.to_dict() @pytest.mark.asyncio -async def test_execute_uses_cache(): +async def test_execute_uses_cache() -> None: s = AsyncSearch() r = object() - s._response = r + s._response = r # type: ignore[assignment] assert r is await s.execute() @pytest.mark.asyncio -async def test_cache_can_be_ignored(async_mock_client): +async def test_cache_can_be_ignored(async_mock_client: Any) -> None: s = AsyncSearch(using="mock") r = object() - s._response = r + s._response = r # type: ignore[assignment] await s.execute(ignore_cache=True) async_mock_client.search.assert_awaited_once_with(index=None, body={}) @pytest.mark.asyncio -async def test_iter_iterates_over_hits(): +async def test_iter_iterates_over_hits() -> None: s = AsyncSearch() - s._response = [1, 2, 3] + s._response = [1, 2, 3] # type: ignore[assignment] assert [1, 2, 3] == [hit async for hit in s] -def test_cache_isnt_cloned(): +def test_cache_isnt_cloned() -> None: s = AsyncSearch() - s._response = object() + s._response = object() # type: ignore[assignment] assert not hasattr(s._clone(), "_response") -def test_search_starts_with_no_query(): +def test_search_starts_with_no_query() -> None: s = AsyncSearch() assert s.query._proxied is None -def test_search_query_combines_query(): +def test_search_query_combines_query() -> None: s = AsyncSearch() s2 = s.query("match", f=42) @@ -82,19 +83,19 @@ def test_search_query_combines_query(): assert s3.query._proxied == query.Bool(must=[query.Match(f=42), query.Match(f=43)]) -def test_query_can_be_assigned_to(): +def test_query_can_be_assigned_to() -> None: s = AsyncSearch() q = Q("match", title="python") - s.query = q + s.query = q # type: ignore assert s.query._proxied is q -def test_query_can_be_wrapped(): +def test_query_can_be_wrapped() -> None: s = AsyncSearch().query("match", title="python") - s.query = Q("function_score", query=s.query, field_value_factor={"field": "rating"}) + s.query = Q("function_score", query=s.query, field_value_factor={"field": "rating"}) # type: ignore assert { "query": { @@ -106,29 +107,29 @@ def test_query_can_be_wrapped(): } == s.to_dict() -def test_using(): +def test_using() -> None: o = object() o2 = object() s = AsyncSearch(using=o) assert s._using is o - s2 = s.using(o2) + s2 = s.using(o2) # type: ignore[arg-type] assert s._using is o assert s2._using is o2 -def test_methods_are_proxied_to_the_query(): +def test_methods_are_proxied_to_the_query() -> None: s = AsyncSearch().query("match_all") assert s.query.to_dict() == {"match_all": {}} -def test_query_always_returns_search(): +def test_query_always_returns_search() -> None: s = AsyncSearch() assert isinstance(s.query("match", f=42), AsyncSearch) -def test_source_copied_on_clone(): +def test_source_copied_on_clone() -> None: s = AsyncSearch().source(False) assert s._clone()._source == s._source assert s._clone()._source is False @@ -142,7 +143,7 @@ def test_source_copied_on_clone(): assert s3._clone()._source == ["some", "fields"] -def test_copy_clones(): +def test_copy_clones() -> None: from copy import copy 
s1 = AsyncSearch().source(["some", "fields"]) @@ -152,7 +153,7 @@ def test_copy_clones(): assert s1 is not s2 -def test_aggs_allow_two_metric(): +def test_aggs_allow_two_metric() -> None: s = AsyncSearch() s.aggs.metric("a", "max", field="a").metric("b", "max", field="b") @@ -162,7 +163,7 @@ def test_aggs_allow_two_metric(): } -def test_aggs_get_copied_on_change(): +def test_aggs_get_copied_on_change() -> None: s = AsyncSearch().query("match_all") s.aggs.bucket("per_tag", "terms", field="f").metric( "max_score", "max", field="score" @@ -175,7 +176,7 @@ def test_aggs_get_copied_on_change(): s4 = s3._clone() s4.aggs.metric("max_score", "max", field="score") - d = { + d: Any = { "query": {"match_all": {}}, "aggs": { "per_tag": { @@ -194,7 +195,7 @@ def test_aggs_get_copied_on_change(): assert d == s4.to_dict() -def test_search_index(): +def test_search_index() -> None: s = AsyncSearch(index="i") assert s._index == ["i"] s = s.index("i2") @@ -225,7 +226,7 @@ def test_search_index(): assert s2._index == ["i", "i2", "i3", "i4", "i5"] -def test_doc_type_document_class(): +def test_doc_type_document_class() -> None: class MyDocument(Document): pass @@ -238,15 +239,15 @@ class MyDocument(Document): assert s._doc_type_map == {} -def test_knn(): +def test_knn() -> None: s = AsyncSearch() with raises(TypeError): - s.knn() + s.knn() # type: ignore[call-arg] with raises(TypeError): - s.knn("field") + s.knn("field") # type: ignore[call-arg] with raises(TypeError): - s.knn("field", 5) + s.knn("field", 5) # type: ignore[call-arg] with raises(ValueError): s.knn("field", 5, 100) with raises(ValueError): @@ -294,7 +295,7 @@ def test_knn(): } == s.to_dict() -def test_rank(): +def test_rank() -> None: s = AsyncSearch() s.rank(rrf=False) assert {} == s.to_dict() @@ -306,7 +307,7 @@ def test_rank(): assert {"rank": {"rrf": {"window_size": 50, "rank_constant": 20}}} == s.to_dict() -def test_sort(): +def test_sort() -> None: s = AsyncSearch() s = s.sort("fielda", "-fieldb") @@ -318,7 +319,7 @@ def test_sort(): assert AsyncSearch().to_dict() == s.to_dict() -def test_sort_by_score(): +def test_sort_by_score() -> None: s = AsyncSearch() s = s.sort("_score") assert {"sort": ["_score"]} == s.to_dict() @@ -328,7 +329,7 @@ def test_sort_by_score(): s.sort("-_score") -def test_collapse(): +def test_collapse() -> None: s = AsyncSearch() inner_hits = {"name": "most_recent", "size": 5, "sort": [{"@timestamp": "desc"}]} @@ -360,7 +361,7 @@ def test_collapse(): assert AsyncSearch().to_dict() == s.to_dict() -def test_slice(): +def test_slice() -> None: s = AsyncSearch() assert {"from": 3, "size": 7} == s[3:10].to_dict() assert {"size": 5} == s[:5].to_dict() @@ -383,7 +384,7 @@ def test_slice(): s[-3:-2] -def test_index(): +def test_index() -> None: s = AsyncSearch() assert {"from": 3, "size": 1} == s[3].to_dict() assert {"from": 3, "size": 1} == s[3][0].to_dict() @@ -393,7 +394,7 @@ def test_index(): s[-3] -def test_search_to_dict(): +def test_search_to_dict() -> None: s = AsyncSearch() assert {} == s.to_dict() @@ -422,7 +423,7 @@ def test_search_to_dict(): assert {"size": 5, "from": 42} == s.to_dict() -def test_complex_example(): +def test_complex_example() -> None: s = AsyncSearch() s = ( s.query("match", title="python") @@ -475,7 +476,7 @@ def test_complex_example(): } == s.to_dict() -def test_reverse(): +def test_reverse() -> None: d = { "query": { "filtered": { @@ -525,14 +526,14 @@ def test_reverse(): assert d == s.to_dict() -def test_from_dict_doesnt_need_query(): +def test_from_dict_doesnt_need_query() -> None: s = 
AsyncSearch.from_dict({"size": 5}) assert {"size": 5} == s.to_dict() @pytest.mark.asyncio -async def test_params_being_passed_to_search(async_mock_client): +async def test_params_being_passed_to_search(async_mock_client: Any) -> None: s = AsyncSearch(using="mock") s = s.params(routing="42") await s.execute() @@ -540,7 +541,7 @@ async def test_params_being_passed_to_search(async_mock_client): async_mock_client.search.assert_awaited_once_with(index=None, body={}, routing="42") -def test_source(): +def test_source() -> None: assert {} == AsyncSearch().source().to_dict() assert { @@ -554,7 +555,7 @@ def test_source(): ).source(["f1", "f2"]).to_dict() -def test_source_on_clone(): +def test_source_on_clone() -> None: assert { "_source": {"includes": ["foo.bar.*"], "excludes": ["foo.one"]}, "query": {"bool": {"filter": [{"term": {"title": "python"}}]}}, @@ -569,7 +570,7 @@ def test_source_on_clone(): } == AsyncSearch().source(False).filter("term", title="python").to_dict() -def test_source_on_clear(): +def test_source_on_clear() -> None: assert ( {} == AsyncSearch() @@ -579,7 +580,7 @@ def test_source_on_clear(): ) -def test_suggest_accepts_global_text(): +def test_suggest_accepts_global_text() -> None: s = AsyncSearch.from_dict( { "suggest": { @@ -601,7 +602,7 @@ def test_suggest_accepts_global_text(): } == s.to_dict() -def test_suggest(): +def test_suggest() -> None: s = AsyncSearch() s = s.suggest("my_suggestion", "pyhton", term={"field": "title"}) @@ -610,7 +611,7 @@ def test_suggest(): } == s.to_dict() -def test_exclude(): +def test_exclude() -> None: s = AsyncSearch() s = s.exclude("match", title="python") @@ -624,7 +625,7 @@ def test_exclude(): @pytest.mark.asyncio -async def test_delete_by_query(async_mock_client): +async def test_delete_by_query(async_mock_client: Any) -> None: s = AsyncSearch(using="mock", index="i").query("match", lang="java") await s.delete() @@ -633,7 +634,7 @@ async def test_delete_by_query(async_mock_client): ) -def test_update_from_dict(): +def test_update_from_dict() -> None: s = AsyncSearch() s.update_from_dict({"indices_boost": [{"important-documents": 2}]}) s.update_from_dict({"_source": ["id", "name"]}) @@ -646,7 +647,7 @@ def test_update_from_dict(): } == s.to_dict() -def test_rescore_query_to_dict(): +def test_rescore_query_to_dict() -> None: s = AsyncSearch(index="index-name") positive_query = Q( @@ -709,10 +710,10 @@ def test_rescore_query_to_dict(): @pytest.mark.asyncio -async def test_empty_search(): +async def test_empty_search() -> None: s = AsyncEmptySearch(index="index-name") s = s.query("match", lang="java") - s.aggs.bucket("versions", A("terms", field="version")) + s.aggs.bucket("versions", "terms", field="version") assert await s.count() == 0 assert [hit async for hit in s] == [] @@ -720,7 +721,7 @@ async def test_empty_search(): await s.delete() # should not error -def test_suggest_completion(): +def test_suggest_completion() -> None: s = AsyncSearch() s = s.suggest("my_suggestion", "pyhton", completion={"field": "title"}) @@ -731,7 +732,7 @@ def test_suggest_completion(): } == s.to_dict() -def test_suggest_regex_query(): +def test_suggest_regex_query() -> None: s = AsyncSearch() s = s.suggest("my_suggestion", regex="py[thon|py]", completion={"field": "title"}) @@ -742,19 +743,19 @@ def test_suggest_regex_query(): } == s.to_dict() -def test_suggest_must_pass_text_or_regex(): +def test_suggest_must_pass_text_or_regex() -> None: s = AsyncSearch() with raises(ValueError): s.suggest("my_suggestion") -def test_suggest_can_only_pass_text_or_regex(): 
+def test_suggest_can_only_pass_text_or_regex() -> None: s = AsyncSearch() with raises(ValueError): s.suggest("my_suggestion", text="python", regex="py[hton|py]") -def test_suggest_regex_must_be_wtih_completion(): +def test_suggest_regex_must_be_wtih_completion() -> None: s = AsyncSearch() with raises(ValueError): s.suggest("my_suggestion", regex="py[thon|py]") diff --git a/tests/_async/test_update_by_query.py b/tests/_async/test_update_by_query.py index c62380b6..4bde5ee0 100644 --- a/tests/_async/test_update_by_query.py +++ b/tests/_async/test_update_by_query.py @@ -16,20 +16,22 @@ # under the License. from copy import deepcopy +from typing import Any import pytest from elasticsearch_dsl import AsyncUpdateByQuery, Q from elasticsearch_dsl.response import UpdateByQueryResponse +from elasticsearch_dsl.search_base import SearchBase -def test_ubq_starts_with_no_query(): +def test_ubq_starts_with_no_query() -> None: ubq = AsyncUpdateByQuery() assert ubq.query._proxied is None -def test_ubq_to_dict(): +def test_ubq_to_dict() -> None: ubq = AsyncUpdateByQuery() assert {} == ubq.to_dict() @@ -45,7 +47,7 @@ def test_ubq_to_dict(): assert {"extra_q": {"term": {"category": "conference"}}} == ubq.to_dict() -def test_complex_example(): +def test_complex_example() -> None: ubq = AsyncUpdateByQuery() ubq = ( ubq.query("match", title="python") @@ -83,7 +85,7 @@ def test_complex_example(): } == ubq.to_dict() -def test_exclude(): +def test_exclude() -> None: ubq = AsyncUpdateByQuery() ubq = ubq.exclude("match", title="python") @@ -96,7 +98,7 @@ def test_exclude(): } == ubq.to_dict() -def test_reverse(): +def test_reverse() -> None: d = { "query": { "filtered": { @@ -132,14 +134,14 @@ def test_reverse(): assert d == ubq.to_dict() -def test_from_dict_doesnt_need_query(): +def test_from_dict_doesnt_need_query() -> None: ubq = AsyncUpdateByQuery.from_dict({"script": {"source": "test"}}) assert {"script": {"source": "test"}} == ubq.to_dict() @pytest.mark.asyncio -async def test_params_being_passed_to_search(async_mock_client): +async def test_params_being_passed_to_search(async_mock_client: Any) -> None: ubq = AsyncUpdateByQuery(using="mock", index="i") ubq = ubq.params(routing="42") await ubq.execute() @@ -147,7 +149,7 @@ async def test_params_being_passed_to_search(async_mock_client): async_mock_client.update_by_query.assert_called_once_with(index=["i"], routing="42") -def test_overwrite_script(): +def test_overwrite_script() -> None: ubq = AsyncUpdateByQuery() ubq = ubq.script( source="ctx._source.likes += params.f", lang="painless", params={"f": 3} @@ -163,12 +165,12 @@ def test_overwrite_script(): assert {"script": {"source": "ctx._source.likes++"}} == ubq.to_dict() -def test_update_by_query_response_success(): - ubqr = UpdateByQueryResponse({}, {"timed_out": False, "failures": []}) +def test_update_by_query_response_success() -> None: + ubqr = UpdateByQueryResponse(SearchBase(), {"timed_out": False, "failures": []}) assert ubqr.success() - ubqr = UpdateByQueryResponse({}, {"timed_out": True, "failures": []}) + ubqr = UpdateByQueryResponse(SearchBase(), {"timed_out": True, "failures": []}) assert not ubqr.success() - ubqr = UpdateByQueryResponse({}, {"timed_out": False, "failures": [{}]}) + ubqr = UpdateByQueryResponse(SearchBase(), {"timed_out": False, "failures": [{}]}) assert not ubqr.success() diff --git a/tests/_sync/test_document.py b/tests/_sync/test_document.py index 99153762..0ba3e419 100644 --- a/tests/_sync/test_document.py +++ b/tests/_sync/test_document.py @@ -15,12 +15,18 @@ # specific 
language governing permissions and limitations # under the License. +# this file creates several documents using bad or no types because +# these are still supported and should be kept functional in spite +# of not having appropriate type hints. For that reason the comment +# below disables many mypy checks that fail as a result of this. +# mypy: disable-error-code="assignment, index, arg-type, call-arg, operator, comparison-overlap, attr-defined" + import codecs import ipaddress import pickle from datetime import datetime from hashlib import md5 -from typing import List, Optional +from typing import Any, Dict, List, Optional import pytest from pytest import raises @@ -94,10 +100,10 @@ class Secret(str): class SecretField(field.CustomField): builtin_type = "text" - def _serialize(self, data): + def _serialize(self, data: Any) -> Any: return codecs.encode(data, "rot_13") - def _deserialize(self, data): + def _deserialize(self, data: Any) -> Any: if isinstance(data, Secret): return data return Secret(codecs.decode(data, "rot_13")) @@ -131,9 +137,9 @@ class Index: name = "test-host" -def test_range_serializes_properly(): +def test_range_serializes_properly() -> None: class D(Document): - lr = field.LongRange() + lr: Range[int] = field.LongRange() d = D(lr=Range(lt=42)) assert 40 in d.lr @@ -144,7 +150,7 @@ class D(Document): assert {"lr": {"lt": 42}} == d.to_dict() -def test_range_deserializes_properly(): +def test_range_deserializes_properly() -> None: class D(InnerDoc): lr = field.LongRange() @@ -154,13 +160,13 @@ class D(InnerDoc): assert 47 not in d.lr -def test_resolve_nested(): +def test_resolve_nested() -> None: nested, field = NestedSecret._index.resolve_nested("secrets.title") assert nested == ["secrets"] assert field is NestedSecret._doc_type.mapping["secrets"]["title"] -def test_conflicting_mapping_raises_error_in_index_to_dict(): +def test_conflicting_mapping_raises_error_in_index_to_dict() -> None: class A(Document): name = field.Text() @@ -175,18 +181,18 @@ class B(Document): i.to_dict() -def test_ip_address_serializes_properly(): +def test_ip_address_serializes_properly() -> None: host = Host(ip=ipaddress.IPv4Address("10.0.0.1")) assert {"ip": "10.0.0.1"} == host.to_dict() -def test_matches_uses_index(): +def test_matches_uses_index() -> None: assert SimpleCommit._matches({"_index": "test-git"}) assert not SimpleCommit._matches({"_index": "not-test-git"}) -def test_matches_with_no_name_always_matches(): +def test_matches_with_no_name_always_matches() -> None: class D(Document): pass @@ -194,7 +200,7 @@ class D(Document): assert D._matches({"_index": "whatever"}) -def test_matches_accepts_wildcards(): +def test_matches_accepts_wildcards() -> None: class MyDoc(Document): class Index: name = "my-*" @@ -203,7 +209,7 @@ class Index: assert not MyDoc._matches({"_index": "not-my-index"}) -def test_assigning_attrlist_to_field(): +def test_assigning_attrlist_to_field() -> None: sc = SimpleCommit() l = ["README", "README.rst"] sc.files = utils.AttrList(l) @@ -211,13 +217,13 @@ def test_assigning_attrlist_to_field(): assert sc.to_dict()["files"] is l -def test_optional_inner_objects_are_not_validated_if_missing(): +def test_optional_inner_objects_are_not_validated_if_missing() -> None: d = OptionalObjectWithRequiredField() - assert d.full_clean() is None + d.full_clean() -def test_custom_field(): +def test_custom_field() -> None: s = SecretDoc(title=Secret("Hello")) assert {"title": "Uryyb"} == s.to_dict() @@ -228,13 +234,13 @@ assert isinstance(s.title, Secret)
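The `Secret`/`SecretField` fixture that these custom-field tests exercise can be hard to reassemble from the hunks; here is a condensed, standalone sketch (same `rot_13` scheme, using only `elasticsearch_dsl` APIs that appear elsewhere in this diff):

```python
# Condensed sketch of the SecretField fixture: values serialize to a
# ROT13-obfuscated string and deserialize back into the Secret marker type.
import codecs
from typing import Any

from elasticsearch_dsl import Document, field


class Secret(str):
    """Marker type so deserialized values are distinguishable from raw text."""


class SecretField(field.CustomField):
    builtin_type = "text"  # mapped as an ordinary text field in Elasticsearch

    def _serialize(self, data: Any) -> Any:
        return codecs.encode(data, "rot_13")

    def _deserialize(self, data: Any) -> Any:
        if isinstance(data, Secret):
            return data
        return Secret(codecs.decode(data, "rot_13"))


class SecretDoc(Document):
    title = SecretField(index="no")


doc = SecretDoc(title=Secret("Hello"))
assert doc.to_dict() == {"title": "Uryyb"}  # stored form is obfuscated
```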
-def test_custom_field_mapping(): +def test_custom_field_mapping() -> None: assert { "properties": {"title": {"index": "no", "type": "text"}} } == SecretDoc._doc_type.mapping.to_dict() -def test_custom_field_in_nested(): +def test_custom_field_in_nested() -> None: s = NestedSecret() s.secrets.append(SecretDoc(title=Secret("Hello"))) @@ -242,7 +248,7 @@ def test_custom_field_in_nested(): assert s.secrets[0].title == "Hello" -def test_multi_works_after_doc_has_been_saved(): +def test_multi_works_after_doc_has_been_saved() -> None: c = SimpleCommit() c.full_clean() c.files.append("setup.py") @@ -250,7 +256,7 @@ def test_multi_works_after_doc_has_been_saved(): assert c.to_dict() == {"files": ["setup.py"]} -def test_multi_works_in_nested_after_doc_has_been_serialized(): +def test_multi_works_in_nested_after_doc_has_been_serialized() -> None: # Issue #359 c = DocWithNested(comments=[Comment(title="First!")]) @@ -259,18 +265,18 @@ def test_multi_works_in_nested_after_doc_has_been_serialized(): assert [] == c.comments[0].tags -def test_null_value_for_object(): +def test_null_value_for_object() -> None: d = MyDoc(inner=None) assert d.inner is None -def test_inherited_doc_types_can_override_index(): +def test_inherited_doc_types_can_override_index() -> None: class MyDocDifferentIndex(MySubDoc): class Index: name = "not-default-index" settings = {"number_of_replicas": 0} - aliases = {"a": {}} + aliases: Dict[str, Any] = {"a": {}} analyzers = [analyzer("my_analizer", tokenizer="keyword")] assert MyDocDifferentIndex._index._name == "not-default-index" @@ -297,7 +303,7 @@ class Index: } -def test_to_dict_with_meta(): +def test_to_dict_with_meta() -> None: d = MySubDoc(title="hello") d.meta.routing = "some-parent" @@ -308,28 +314,28 @@ def test_to_dict_with_meta(): } == d.to_dict(True) -def test_to_dict_with_meta_includes_custom_index(): +def test_to_dict_with_meta_includes_custom_index() -> None: d = MySubDoc(title="hello") d.meta.index = "other-index" assert {"_index": "other-index", "_source": {"title": "hello"}} == d.to_dict(True) -def test_to_dict_without_skip_empty_will_include_empty_fields(): +def test_to_dict_without_skip_empty_will_include_empty_fields() -> None: d = MySubDoc(tags=[], title=None, inner={}) assert {} == d.to_dict() assert {"tags": [], "title": None, "inner": {}} == d.to_dict(skip_empty=False) -def test_attribute_can_be_removed(): +def test_attribute_can_be_removed() -> None: d = MyDoc(title="hello") del d.title assert "title" not in d._d_ -def test_doc_type_can_be_correctly_pickled(): +def test_doc_type_can_be_correctly_pickled() -> None: d = DocWithNested( title="Hello World!", comments=[Comment(title="hellp")], meta={"id": 42} ) @@ -344,7 +350,7 @@ def test_doc_type_can_be_correctly_pickled(): assert isinstance(d2.comments[0], Comment) -def test_meta_is_accessible_even_on_empty_doc(): +def test_meta_is_accessible_even_on_empty_doc() -> None: d = MyDoc() d.meta @@ -352,7 +358,7 @@ def test_meta_is_accessible_even_on_empty_doc(): d.meta -def test_meta_field_mapping(): +def test_meta_field_mapping() -> None: class User(Document): username = field.Text() @@ -371,7 +377,7 @@ class Meta: } == User._doc_type.mapping.to_dict() -def test_multi_value_fields(): +def test_multi_value_fields() -> None: class Blog(Document): tags = field.Keyword(multi=True) @@ -382,19 +388,19 @@ class Blog(Document): assert ["search", "python"] == b.tags -def test_docs_with_properties(): +def test_docs_with_properties() -> None: class User(Document): - pwd_hash = field.Text() + pwd_hash: str = field.Text() - 
def check_password(self, pwd): + def check_password(self, pwd: bytes) -> bool: return md5(pwd).hexdigest() == self.pwd_hash @property - def password(self): + def password(self) -> None: raise AttributeError("readonly") @password.setter - def password(self, pwd): + def password(self, pwd: bytes) -> None: self.pwd_hash = md5(pwd).hexdigest() u = User(pwd_hash=md5(b"secret").hexdigest()) @@ -410,7 +416,7 @@ def password(self, pwd): u.password -def test_nested_can_be_assigned_to(): +def test_nested_can_be_assigned_to() -> None: d1 = DocWithNested(comments=[Comment(title="First!")]) d2 = DocWithNested() @@ -421,13 +427,13 @@ def test_nested_can_be_assigned_to(): assert isinstance(d2.comments[0], Comment) -def test_nested_can_be_none(): +def test_nested_can_be_none() -> None: d = DocWithNested(comments=None, title="Hello World!") assert {"title": "Hello World!"} == d.to_dict() -def test_nested_defaults_to_list_and_can_be_updated(): +def test_nested_defaults_to_list_and_can_be_updated() -> None: md = DocWithNested() assert [] == md.comments @@ -436,7 +442,7 @@ def test_nested_defaults_to_list_and_can_be_updated(): assert {"comments": [{"title": "hello World!"}]} == md.to_dict() -def test_to_dict_is_recursive_and_can_cope_with_multi_values(): +def test_to_dict_is_recursive_and_can_cope_with_multi_values() -> None: md = MyDoc(name=["a", "b", "c"]) md.inner = [MyInner(old_field="of1"), MyInner(old_field="of2")] @@ -448,13 +454,13 @@ def test_to_dict_is_recursive_and_can_cope_with_multi_values(): } == md.to_dict() -def test_to_dict_ignores_empty_collections(): +def test_to_dict_ignores_empty_collections() -> None: md = MySubDoc(name="", address={}, count=0, valid=False, tags=[]) assert {"name": "", "count": 0, "valid": False} == md.to_dict() -def test_declarative_mapping_definition(): +def test_declarative_mapping_definition() -> None: assert issubclass(MyDoc, Document) assert hasattr(MyDoc, "_doc_type") assert { @@ -467,7 +473,7 @@ def test_declarative_mapping_definition(): } == MyDoc._doc_type.mapping.to_dict() -def test_you_can_supply_own_mapping_instance(): +def test_you_can_supply_own_mapping_instance() -> None: class MyD(Document): title = field.Text() @@ -481,7 +487,7 @@ class Meta: } == MyD._doc_type.mapping.to_dict() -def test_document_can_be_created_dynamically(): +def test_document_can_be_created_dynamically() -> None: n = datetime.now() md = MyDoc(title="hello") md.name = "My Fancy Document!" 
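Worth noting how the suppression comment at the top of this file (and its async twin) behaves: it is mypy's file-scoped inline configuration, so only the listed error codes are silenced, and only in this one file. A hypothetical toy module, not taken from the PR, shows the mechanics:

```python
# legacy_doc.py -- hypothetical demo of a file-scoped mypy directive like
# the one added to both test_document.py files: the listed error codes are
# suppressed in this file only; every other check still runs under --strict.
# mypy: disable-error-code="assignment, arg-type"

from typing import List

tags: List[str] = None  # an [assignment] error anywhere else; silenced here


def join(parts: List[str]) -> str:
    return ",".join(parts)


print(join(tags or []))  # runtime guard; the loose typing stays confined
```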
@@ -502,14 +508,14 @@ def test_document_can_be_created_dynamically(): } == md.to_dict() -def test_invalid_date_will_raise_exception(): +def test_invalid_date_will_raise_exception() -> None: md = MyDoc() md.created_at = "not-a-date" with raises(ValidationException): md.full_clean() -def test_document_inheritance(): +def test_document_inheritance() -> None: assert issubclass(MySubDoc, MyDoc) assert issubclass(MySubDoc, Document) assert hasattr(MySubDoc, "_doc_type") @@ -523,7 +529,7 @@ def test_document_inheritance(): } == MySubDoc._doc_type.mapping.to_dict() -def test_child_class_can_override_parent(): +def test_child_class_can_override_parent() -> None: class A(Document): o = field.Object(dynamic=False, properties={"a": field.Text()}) @@ -541,7 +547,7 @@ class B(A): } == B._doc_type.mapping.to_dict() -def test_meta_fields_are_stored_in_meta_and_ignored_by_to_dict(): +def test_meta_fields_are_stored_in_meta_and_ignored_by_to_dict() -> None: md = MySubDoc(meta={"id": 42}, name="My First doc!") md.meta.index = "my-index" @@ -551,7 +557,7 @@ def test_meta_fields_are_stored_in_meta_and_ignored_by_to_dict(): assert {"id": 42, "index": "my-index"} == md.meta.to_dict() -def test_index_inheritance(): +def test_index_inheritance() -> None: assert issubclass(MyMultiSubDoc, MySubDoc) assert issubclass(MyMultiSubDoc, MyDoc2) assert issubclass(MyMultiSubDoc, Document) @@ -568,7 +574,7 @@ def test_index_inheritance(): } == MyMultiSubDoc._doc_type.mapping.to_dict() -def test_meta_fields_can_be_set_directly_in_init(): +def test_meta_fields_can_be_set_directly_in_init() -> None: p = object() md = MyDoc(_id=p, title="Hello World!") @@ -576,27 +582,27 @@ def test_meta_fields_can_be_set_directly_in_init(): @pytest.mark.sync -def test_save_no_index(mock_client): +def test_save_no_index(mock_client: Any) -> None: md = MyDoc() with raises(ValidationException): md.save(using="mock") @pytest.mark.sync -def test_delete_no_index(mock_client): +def test_delete_no_index(mock_client: Any) -> None: md = MyDoc() with raises(ValidationException): md.delete(using="mock") @pytest.mark.sync -def test_update_no_fields(): +def test_update_no_fields() -> None: md = MyDoc() with raises(IllegalOperation): md.update() -def test_search_with_custom_alias_and_index(): +def test_search_with_custom_alias_and_index() -> None: search_object = MyDoc.search( using="staging", index=["custom_index1", "custom_index2"] ) @@ -605,7 +611,7 @@ def test_search_with_custom_alias_and_index(): assert search_object._index == ["custom_index1", "custom_index2"] -def test_from_es_respects_underscored_non_meta_fields(): +def test_from_es_respects_underscored_non_meta_fields() -> None: doc = { "_index": "test-index", "_id": "elasticsearch", @@ -629,7 +635,7 @@ class Index: assert c._tagline == "You know, for search" -def test_nested_and_object_inner_doc(): +def test_nested_and_object_inner_doc() -> None: class MySubDocWithNested(MyDoc): nested_inner = field.Nested(MyInner) @@ -646,7 +652,7 @@ class MySubDocWithNested(MyDoc): } -def test_doc_with_type_hints(): +def test_doc_with_type_hints() -> None: class TypedInnerDoc(InnerDoc): st: M[str] dt: M[Optional[datetime]] @@ -656,16 +662,16 @@ class TypedDoc(Document): st: str dt: Optional[datetime] li: List[int] - ob: Optional[TypedInnerDoc] - ns: Optional[List[TypedInnerDoc]] + ob: TypedInnerDoc + ns: List[TypedInnerDoc] ip: Optional[str] = field.Ip() k1: str = field.Keyword(required=True) k2: M[str] = field.Keyword() k3: str = mapped_field(field.Keyword(), default="foo") - k4: M[Optional[str]] = 
mapped_field(field.Keyword()) + k4: M[Optional[str]] = mapped_field(field.Keyword()) # type: ignore[misc] s1: Secret = SecretField() s2: M[Secret] = SecretField() - s3: Secret = mapped_field(SecretField()) + s3: Secret = mapped_field(SecretField()) # type: ignore[misc] s4: M[Optional[Secret]] = mapped_field( SecretField(), default_factory=lambda: "foo" ) @@ -707,27 +713,34 @@ class TypedDoc(Document): assert doc.s4 == "foo" with raises(ValidationException) as exc_info: doc.full_clean() - assert set(exc_info.value.args[0].keys()) == {"st", "k1", "k2", "s1", "s2", "s3"} + assert set(exc_info.value.args[0].keys()) == { + "st", + "k1", + "k2", + "ob", + "s1", + "s2", + "s3", + } doc.st = "s" doc.li = [1, 2, 3] doc.k1 = "k1" doc.k2 = "k2" + doc.ob.st = "s" + doc.ob.li = [1] doc.s1 = "s1" doc.s2 = "s2" doc.s3 = "s3" doc.full_clean() - doc.ob = TypedInnerDoc() + doc.ob = TypedInnerDoc(li=[1]) with raises(ValidationException) as exc_info: doc.full_clean() assert set(exc_info.value.args[0].keys()) == {"ob"} assert set(exc_info.value.args[0]["ob"][0].args[0].keys()) == {"st"} doc.ob.st = "s" - doc.ob.li = [1] - doc.full_clean() - doc.ns.append(TypedInnerDoc(li=[1, 2])) with raises(ValidationException) as exc_info: doc.full_clean() @@ -766,7 +779,7 @@ class TypedDoc(Document): assert s.to_dict() == {"sort": ["st", {"dt": {"order": "desc"}}, "ob.st"]} -def test_instrumented_field(): +def test_instrumented_field() -> None: class Child(InnerDoc): st: M[str] diff --git a/tests/_sync/test_faceted_search.py b/tests/_sync/test_faceted_search.py index 079dd146..fa85e99b 100644 --- a/tests/_sync/test_faceted_search.py +++ b/tests/_sync/test_faceted_search.py @@ -28,10 +28,10 @@ class BlogSearch(FacetedSearch): doc_types = ["user", "post"] - fields = ( + fields = [ "title^5", "body", - ) + ] facets = { "category": TermsFacet(field="category.raw"), @@ -39,7 +39,7 @@ class BlogSearch(FacetedSearch): } -def test_query_is_created_properly(): +def test_query_is_created_properly() -> None: bs = BlogSearch("python search") s = bs.build_search() @@ -56,13 +56,13 @@ def test_query_is_created_properly(): }, }, "query": { - "multi_match": {"fields": ("title^5", "body"), "query": "python search"} + "multi_match": {"fields": ["title^5", "body"], "query": "python search"} }, "highlight": {"fields": {"body": {}, "title": {}}}, } == s.to_dict() -def test_query_is_created_properly_with_sort_tuple(): +def test_query_is_created_properly_with_sort_tuple() -> None: bs = BlogSearch("python search", sort=("category", "-title")) s = bs.build_search() @@ -79,14 +79,14 @@ def test_query_is_created_properly_with_sort_tuple(): }, }, "query": { - "multi_match": {"fields": ("title^5", "body"), "query": "python search"} + "multi_match": {"fields": ["title^5", "body"], "query": "python search"} }, "highlight": {"fields": {"body": {}, "title": {}}}, "sort": ["category", {"title": {"order": "desc"}}], } == s.to_dict() -def test_filter_is_applied_to_search_but_not_relevant_facet(): +def test_filter_is_applied_to_search_but_not_relevant_facet() -> None: bs = BlogSearch("python search", filters={"category": "elastic"}) s = bs.build_search() @@ -103,13 +103,13 @@ def test_filter_is_applied_to_search_but_not_relevant_facet(): }, "post_filter": {"terms": {"category.raw": ["elastic"]}}, "query": { - "multi_match": {"fields": ("title^5", "body"), "query": "python search"} + "multi_match": {"fields": ["title^5", "body"], "query": "python search"} }, "highlight": {"fields": {"body": {}, "title": {}}}, } == s.to_dict() -def 
test_filters_are_applied_to_search_ant_relevant_facets(): +def test_filters_are_applied_to_search_ant_relevant_facets() -> None: bs = BlogSearch( "python search", filters={"category": "elastic", "tags": ["python", "django"]} ) @@ -135,17 +135,17 @@ def test_filters_are_applied_to_search_ant_relevant_facets(): }, }, "query": { - "multi_match": {"fields": ("title^5", "body"), "query": "python search"} + "multi_match": {"fields": ["title^5", "body"], "query": "python search"} }, "post_filter": {"bool": {}}, "highlight": {"fields": {"body": {}, "title": {}}}, } == d -def test_date_histogram_facet_with_1970_01_01_date(): +def test_date_histogram_facet_with_1970_01_01_date() -> None: dhf = DateHistogramFacet() - assert dhf.get_value({"key": None}) == datetime(1970, 1, 1, 0, 0) - assert dhf.get_value({"key": 0}) == datetime(1970, 1, 1, 0, 0) + assert dhf.get_value({"key": None}) == datetime(1970, 1, 1, 0, 0) # type: ignore[arg-type] + assert dhf.get_value({"key": 0}) == datetime(1970, 1, 1, 0, 0) # type: ignore[arg-type] @pytest.mark.parametrize( @@ -175,7 +175,7 @@ def test_date_histogram_facet_with_1970_01_01_date(): ("fixed_interval", "1h"), ], ) -def test_date_histogram_interval_types(interval_type, interval): +def test_date_histogram_interval_types(interval_type: str, interval: str) -> None: dhf = DateHistogramFacet(field="@timestamp", **{interval_type: interval}) assert dhf.get_aggregation().to_dict() == { "date_histogram": { @@ -187,14 +187,14 @@ def test_date_histogram_interval_types(interval_type, interval): dhf.get_value_filter(datetime.now()) -def test_date_histogram_no_interval_keyerror(): +def test_date_histogram_no_interval_keyerror() -> None: dhf = DateHistogramFacet(field="@timestamp") with pytest.raises(KeyError) as e: dhf.get_value_filter(datetime.now()) assert str(e.value) == "'interval'" -def test_params_added_to_search(): +def test_params_added_to_search() -> None: bs = BlogSearch("python search") assert bs._s._params == {} bs.params(routing="42") diff --git a/tests/_sync/test_index.py b/tests/_sync/test_index.py index 4d1238f8..028e70d0 100644 --- a/tests/_sync/test_index.py +++ b/tests/_sync/test_index.py @@ -17,6 +17,7 @@ import string from random import choice +from typing import Any, Dict import pytest from pytest import raises @@ -29,7 +30,7 @@ class Post(Document): published_from = Date() -def test_multiple_doc_types_will_combine_mappings(): +def test_multiple_doc_types_will_combine_mappings() -> None: class User(Document): username = Text() @@ -47,16 +48,16 @@ class User(Document): } == i.to_dict() -def test_search_is_limited_to_index_name(): +def test_search_is_limited_to_index_name() -> None: i = Index("my-index") s = i.search() assert s._index == ["my-index"] -def test_cloned_index_has_copied_settings_and_using(): +def test_cloned_index_has_copied_settings_and_using() -> None: client = object() - i = Index("my-index", using=client) + i = Index("my-index", using=client) # type: ignore[arg-type] i.settings(number_of_shards=1) i2 = i.clone("my-other-index") @@ -67,13 +68,13 @@ def test_cloned_index_has_copied_settings_and_using(): assert i._settings is not i2._settings -def test_cloned_index_has_analysis_attribute(): +def test_cloned_index_has_analysis_attribute() -> None: """ Regression test for Issue #582 in which `AsyncIndex.clone()` was not copying over the `_analysis` attribute. 
""" client = object() - i = Index("my-index", using=client) + i = Index("my-index", using=client) # type: ignore[arg-type] random_analyzer_name = "".join(choice(string.ascii_letters) for _ in range(100)) random_analyzer = analyzer( @@ -87,7 +88,7 @@ def test_cloned_index_has_analysis_attribute(): assert i.to_dict()["settings"]["analysis"] == i2.to_dict()["settings"]["analysis"] -def test_settings_are_saved(): +def test_settings_are_saved() -> None: i = Index("i") i.settings(number_of_replicas=0) i.settings(number_of_shards=1) @@ -95,7 +96,7 @@ def test_settings_are_saved(): assert {"settings": {"number_of_shards": 1, "number_of_replicas": 0}} == i.to_dict() -def test_registered_doc_type_included_in_to_dict(): +def test_registered_doc_type_included_in_to_dict() -> None: i = Index("i", using="alias") i.document(Post) @@ -109,7 +110,7 @@ def test_registered_doc_type_included_in_to_dict(): } == i.to_dict() -def test_registered_doc_type_included_in_search(): +def test_registered_doc_type_included_in_search() -> None: i = Index("i", using="alias") i.document(Post) @@ -118,9 +119,9 @@ def test_registered_doc_type_included_in_search(): assert s._doc_type == [Post] -def test_aliases_add_to_object(): +def test_aliases_add_to_object() -> None: random_alias = "".join(choice(string.ascii_letters) for _ in range(100)) - alias_dict = {random_alias: {}} + alias_dict: Dict[str, Any] = {random_alias: {}} index = Index("i", using="alias") index.aliases(**alias_dict) @@ -128,9 +129,9 @@ def test_aliases_add_to_object(): assert index._aliases == alias_dict -def test_aliases_returned_from_to_dict(): +def test_aliases_returned_from_to_dict() -> None: random_alias = "".join(choice(string.ascii_letters) for _ in range(100)) - alias_dict = {random_alias: {}} + alias_dict: Dict[str, Any] = {random_alias: {}} index = Index("i", using="alias") index.aliases(**alias_dict) @@ -138,7 +139,7 @@ def test_aliases_returned_from_to_dict(): assert index._aliases == index.to_dict()["aliases"] == alias_dict -def test_analyzers_added_to_object(): +def test_analyzers_added_to_object() -> None: random_analyzer_name = "".join(choice(string.ascii_letters) for _ in range(100)) random_analyzer = analyzer( random_analyzer_name, tokenizer="standard", filter="standard" @@ -154,7 +155,7 @@ def test_analyzers_added_to_object(): } -def test_analyzers_returned_from_to_dict(): +def test_analyzers_returned_from_to_dict() -> None: random_analyzer_name = "".join(choice(string.ascii_letters) for _ in range(100)) random_analyzer = analyzer( random_analyzer_name, tokenizer="standard", filter="standard" @@ -167,7 +168,7 @@ def test_analyzers_returned_from_to_dict(): ] == {"filter": ["standard"], "type": "custom", "tokenizer": "standard"} -def test_conflicting_analyzer_raises_error(): +def test_conflicting_analyzer_raises_error() -> None: i = Index("i") i.analyzer("my_analyzer", tokenizer="whitespace", filter=["lowercase", "stop"]) @@ -175,7 +176,7 @@ def test_conflicting_analyzer_raises_error(): i.analyzer("my_analyzer", tokenizer="keyword", filter=["lowercase", "stop"]) -def test_index_template_can_have_order(): +def test_index_template_can_have_order() -> None: i = Index("i-*") it = i.as_template("i", order=2) @@ -183,7 +184,7 @@ def test_index_template_can_have_order(): @pytest.mark.sync -def test_index_template_save_result(mock_client): +def test_index_template_save_result(mock_client: Any) -> None: it = IndexTemplate("test-template", "test-*") assert it.save(using="mock") == mock_client.indices.put_template() diff --git 
a/tests/_sync/test_mapping.py b/tests/_sync/test_mapping.py index 500b5dde..8ba2f2e6 100644 --- a/tests/_sync/test_mapping.py +++ b/tests/_sync/test_mapping.py @@ -20,7 +20,7 @@ from elasticsearch_dsl import Keyword, Mapping, Nested, Text, analysis -def test_mapping_can_has_fields(): +def test_mapping_can_has_fields() -> None: m = Mapping() m.field("name", "text").field("tags", "keyword") @@ -29,7 +29,7 @@ def test_mapping_can_has_fields(): } == m.to_dict() -def test_mapping_update_is_recursive(): +def test_mapping_update_is_recursive() -> None: m1 = Mapping() m1.field("title", "text") m1.field("author", "object") @@ -62,7 +62,7 @@ def test_mapping_update_is_recursive(): } == m1.to_dict() -def test_properties_can_iterate_over_all_the_fields(): +def test_properties_can_iterate_over_all_the_fields() -> None: m = Mapping() m.field("f1", "text", test_attr="f1", fields={"f2": Keyword(test_attr="f2")}) m.field("f3", Nested(test_attr="f3", properties={"f4": Text(test_attr="f4")})) @@ -72,7 +72,7 @@ def test_properties_can_iterate_over_all_the_fields(): } -def test_mapping_can_collect_all_analyzers_and_normalizers(): +def test_mapping_can_collect_all_analyzers_and_normalizers() -> None: a1 = analysis.analyzer( "my_analyzer1", tokenizer="keyword", @@ -145,7 +145,7 @@ def test_mapping_can_collect_all_analyzers_and_normalizers(): assert json.loads(json.dumps(m.to_dict())) == m.to_dict() -def test_mapping_can_collect_multiple_analyzers(): +def test_mapping_can_collect_multiple_analyzers() -> None: a1 = analysis.analyzer( "my_analyzer1", tokenizer="keyword", @@ -191,7 +191,7 @@ def test_mapping_can_collect_multiple_analyzers(): } == m._collect_analysis() -def test_even_non_custom_analyzers_can_have_params(): +def test_even_non_custom_analyzers_can_have_params() -> None: a1 = analysis.analyzer("whitespace", type="pattern", pattern=r"\\s+") m = Mapping() m.field("title", "text", analyzer=a1) @@ -201,14 +201,14 @@ def test_even_non_custom_analyzers_can_have_params(): } == m._collect_analysis() -def test_resolve_field_can_resolve_multifields(): +def test_resolve_field_can_resolve_multifields() -> None: m = Mapping() m.field("title", "text", fields={"keyword": Keyword()}) assert isinstance(m.resolve_field("title.keyword"), Keyword) -def test_resolve_nested(): +def test_resolve_nested() -> None: m = Mapping() m.field("n1", "nested", properties={"n2": Nested(properties={"k1": Keyword()})}) m.field("k2", "keyword") diff --git a/tests/_sync/test_search.py b/tests/_sync/test_search.py index 61dadcb0..638c614d 100644 --- a/tests/_sync/test_search.py +++ b/tests/_sync/test_search.py @@ -16,61 +16,62 @@ # under the License. 
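# The test_mapping.py hunks above only add `-> None` annotations; the
# Mapping API they exercise is unchanged. As a minimal sketch of the
# behaviour those tests assert (field names here are illustrative, not
# taken from the patch):
from elasticsearch_dsl import Keyword, Mapping

m = Mapping()
# field() returns the mapping itself, so definitions can be chained
m.field("name", "text").field("tags", "keyword")
# multi-fields are declared via `fields` and resolved with dotted paths
m.field("title", "text", fields={"keyword": Keyword()})

assert m.to_dict()["properties"]["tags"] == {"type": "keyword"}
assert isinstance(m.resolve_field("title.keyword"), Keyword)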
from copy import deepcopy +from typing import Any import pytest from pytest import raises -from elasticsearch_dsl import A, Document, EmptySearch, Q, Search, query +from elasticsearch_dsl import Document, EmptySearch, Q, Search, query from elasticsearch_dsl.exceptions import IllegalOperation -def test_expand__to_dot_is_respected(): +def test_expand__to_dot_is_respected() -> None: s = Search().query("match", a__b=42, _expand__to_dot=False) assert {"query": {"match": {"a__b": 42}}} == s.to_dict() @pytest.mark.sync -def test_execute_uses_cache(): +def test_execute_uses_cache() -> None: s = Search() r = object() - s._response = r + s._response = r # type: ignore[assignment] assert r is s.execute() @pytest.mark.sync -def test_cache_can_be_ignored(mock_client): +def test_cache_can_be_ignored(mock_client: Any) -> None: s = Search(using="mock") r = object() - s._response = r + s._response = r # type: ignore[assignment] s.execute(ignore_cache=True) mock_client.search.assert_called_once_with(index=None, body={}) @pytest.mark.sync -def test_iter_iterates_over_hits(): +def test_iter_iterates_over_hits() -> None: s = Search() - s._response = [1, 2, 3] + s._response = [1, 2, 3] # type: ignore[assignment] assert [1, 2, 3] == [hit for hit in s] -def test_cache_isnt_cloned(): +def test_cache_isnt_cloned() -> None: s = Search() - s._response = object() + s._response = object() # type: ignore[assignment] assert not hasattr(s._clone(), "_response") -def test_search_starts_with_no_query(): +def test_search_starts_with_no_query() -> None: s = Search() assert s.query._proxied is None -def test_search_query_combines_query(): +def test_search_query_combines_query() -> None: s = Search() s2 = s.query("match", f=42) @@ -82,19 +83,19 @@ def test_search_query_combines_query(): assert s3.query._proxied == query.Bool(must=[query.Match(f=42), query.Match(f=43)]) -def test_query_can_be_assigned_to(): +def test_query_can_be_assigned_to() -> None: s = Search() q = Q("match", title="python") - s.query = q + s.query = q # type: ignore assert s.query._proxied is q -def test_query_can_be_wrapped(): +def test_query_can_be_wrapped() -> None: s = Search().query("match", title="python") - s.query = Q("function_score", query=s.query, field_value_factor={"field": "rating"}) + s.query = Q("function_score", query=s.query, field_value_factor={"field": "rating"}) # type: ignore assert { "query": { @@ -106,29 +107,29 @@ def test_query_can_be_wrapped(): } == s.to_dict() -def test_using(): +def test_using() -> None: o = object() o2 = object() s = Search(using=o) assert s._using is o - s2 = s.using(o2) + s2 = s.using(o2) # type: ignore[arg-type] assert s._using is o assert s2._using is o2 -def test_methods_are_proxied_to_the_query(): +def test_methods_are_proxied_to_the_query() -> None: s = Search().query("match_all") assert s.query.to_dict() == {"match_all": {}} -def test_query_always_returns_search(): +def test_query_always_returns_search() -> None: s = Search() assert isinstance(s.query("match", f=42), Search) -def test_source_copied_on_clone(): +def test_source_copied_on_clone() -> None: s = Search().source(False) assert s._clone()._source == s._source assert s._clone()._source is False @@ -142,7 +143,7 @@ def test_source_copied_on_clone(): assert s3._clone()._source == ["some", "fields"] -def test_copy_clones(): +def test_copy_clones() -> None: from copy import copy s1 = Search().source(["some", "fields"]) @@ -152,7 +153,7 @@ def test_copy_clones(): assert s1 is not s2 -def test_aggs_allow_two_metric(): +def test_aggs_allow_two_metric() 
-> None: s = Search() s.aggs.metric("a", "max", field="a").metric("b", "max", field="b") @@ -162,7 +163,7 @@ def test_aggs_allow_two_metric(): } -def test_aggs_get_copied_on_change(): +def test_aggs_get_copied_on_change() -> None: s = Search().query("match_all") s.aggs.bucket("per_tag", "terms", field="f").metric( "max_score", "max", field="score" @@ -175,7 +176,7 @@ def test_aggs_get_copied_on_change(): s4 = s3._clone() s4.aggs.metric("max_score", "max", field="score") - d = { + d: Any = { "query": {"match_all": {}}, "aggs": { "per_tag": { @@ -194,7 +195,7 @@ def test_aggs_get_copied_on_change(): assert d == s4.to_dict() -def test_search_index(): +def test_search_index() -> None: s = Search(index="i") assert s._index == ["i"] s = s.index("i2") @@ -225,7 +226,7 @@ def test_search_index(): assert s2._index == ["i", "i2", "i3", "i4", "i5"] -def test_doc_type_document_class(): +def test_doc_type_document_class() -> None: class MyDocument(Document): pass @@ -238,15 +239,15 @@ class MyDocument(Document): assert s._doc_type_map == {} -def test_knn(): +def test_knn() -> None: s = Search() with raises(TypeError): - s.knn() + s.knn() # type: ignore[call-arg] with raises(TypeError): - s.knn("field") + s.knn("field") # type: ignore[call-arg] with raises(TypeError): - s.knn("field", 5) + s.knn("field", 5) # type: ignore[call-arg] with raises(ValueError): s.knn("field", 5, 100) with raises(ValueError): @@ -294,7 +295,7 @@ def test_knn(): } == s.to_dict() -def test_rank(): +def test_rank() -> None: s = Search() s.rank(rrf=False) assert {} == s.to_dict() @@ -306,7 +307,7 @@ def test_rank(): assert {"rank": {"rrf": {"window_size": 50, "rank_constant": 20}}} == s.to_dict() -def test_sort(): +def test_sort() -> None: s = Search() s = s.sort("fielda", "-fieldb") @@ -318,7 +319,7 @@ def test_sort(): assert Search().to_dict() == s.to_dict() -def test_sort_by_score(): +def test_sort_by_score() -> None: s = Search() s = s.sort("_score") assert {"sort": ["_score"]} == s.to_dict() @@ -328,7 +329,7 @@ def test_sort_by_score(): s.sort("-_score") -def test_collapse(): +def test_collapse() -> None: s = Search() inner_hits = {"name": "most_recent", "size": 5, "sort": [{"@timestamp": "desc"}]} @@ -360,7 +361,7 @@ def test_collapse(): assert Search().to_dict() == s.to_dict() -def test_slice(): +def test_slice() -> None: s = Search() assert {"from": 3, "size": 7} == s[3:10].to_dict() assert {"size": 5} == s[:5].to_dict() @@ -383,7 +384,7 @@ def test_slice(): s[-3:-2] -def test_index(): +def test_index() -> None: s = Search() assert {"from": 3, "size": 1} == s[3].to_dict() assert {"from": 3, "size": 1} == s[3][0].to_dict() @@ -393,7 +394,7 @@ def test_index(): s[-3] -def test_search_to_dict(): +def test_search_to_dict() -> None: s = Search() assert {} == s.to_dict() @@ -422,7 +423,7 @@ def test_search_to_dict(): assert {"size": 5, "from": 42} == s.to_dict() -def test_complex_example(): +def test_complex_example() -> None: s = Search() s = ( s.query("match", title="python") @@ -475,7 +476,7 @@ def test_complex_example(): } == s.to_dict() -def test_reverse(): +def test_reverse() -> None: d = { "query": { "filtered": { @@ -525,14 +526,14 @@ def test_reverse(): assert d == s.to_dict() -def test_from_dict_doesnt_need_query(): +def test_from_dict_doesnt_need_query() -> None: s = Search.from_dict({"size": 5}) assert {"size": 5} == s.to_dict() @pytest.mark.sync -def test_params_being_passed_to_search(mock_client): +def test_params_being_passed_to_search(mock_client: Any) -> None: s = Search(using="mock") s = s.params(routing="42") 
s.execute() @@ -540,7 +541,7 @@ def test_params_being_passed_to_search(mock_client): mock_client.search.assert_called_once_with(index=None, body={}, routing="42") -def test_source(): +def test_source() -> None: assert {} == Search().source().to_dict() assert { @@ -554,7 +555,7 @@ def test_source(): ).source(["f1", "f2"]).to_dict() -def test_source_on_clone(): +def test_source_on_clone() -> None: assert { "_source": {"includes": ["foo.bar.*"], "excludes": ["foo.one"]}, "query": {"bool": {"filter": [{"term": {"title": "python"}}]}}, @@ -567,7 +568,7 @@ def test_source_on_clone(): } == Search().source(False).filter("term", title="python").to_dict() -def test_source_on_clear(): +def test_source_on_clear() -> None: assert ( {} == Search() @@ -577,7 +578,7 @@ def test_source_on_clear(): ) -def test_suggest_accepts_global_text(): +def test_suggest_accepts_global_text() -> None: s = Search.from_dict( { "suggest": { @@ -599,7 +600,7 @@ def test_suggest_accepts_global_text(): } == s.to_dict() -def test_suggest(): +def test_suggest() -> None: s = Search() s = s.suggest("my_suggestion", "pyhton", term={"field": "title"}) @@ -608,7 +609,7 @@ def test_suggest(): } == s.to_dict() -def test_exclude(): +def test_exclude() -> None: s = Search() s = s.exclude("match", title="python") @@ -622,7 +623,7 @@ def test_exclude(): @pytest.mark.sync -def test_delete_by_query(mock_client): +def test_delete_by_query(mock_client: Any) -> None: s = Search(using="mock", index="i").query("match", lang="java") s.delete() @@ -631,7 +632,7 @@ def test_delete_by_query(mock_client): ) -def test_update_from_dict(): +def test_update_from_dict() -> None: s = Search() s.update_from_dict({"indices_boost": [{"important-documents": 2}]}) s.update_from_dict({"_source": ["id", "name"]}) @@ -644,7 +645,7 @@ def test_update_from_dict(): } == s.to_dict() -def test_rescore_query_to_dict(): +def test_rescore_query_to_dict() -> None: s = Search(index="index-name") positive_query = Q( @@ -707,10 +708,10 @@ def test_rescore_query_to_dict(): @pytest.mark.sync -def test_empty_search(): +def test_empty_search() -> None: s = EmptySearch(index="index-name") s = s.query("match", lang="java") - s.aggs.bucket("versions", A("terms", field="version")) + s.aggs.bucket("versions", "terms", field="version") assert s.count() == 0 assert [hit for hit in s] == [] @@ -718,7 +719,7 @@ def test_empty_search(): s.delete() # should not error -def test_suggest_completion(): +def test_suggest_completion() -> None: s = Search() s = s.suggest("my_suggestion", "pyhton", completion={"field": "title"}) @@ -729,7 +730,7 @@ def test_suggest_completion(): } == s.to_dict() -def test_suggest_regex_query(): +def test_suggest_regex_query() -> None: s = Search() s = s.suggest("my_suggestion", regex="py[thon|py]", completion={"field": "title"}) @@ -740,19 +741,19 @@ def test_suggest_regex_query(): } == s.to_dict() -def test_suggest_must_pass_text_or_regex(): +def test_suggest_must_pass_text_or_regex() -> None: s = Search() with raises(ValueError): s.suggest("my_suggestion") -def test_suggest_can_only_pass_text_or_regex(): +def test_suggest_can_only_pass_text_or_regex() -> None: s = Search() with raises(ValueError): s.suggest("my_suggestion", text="python", regex="py[hton|py]") -def test_suggest_regex_must_be_wtih_completion(): +def test_suggest_regex_must_be_wtih_completion() -> None: s = Search() with raises(ValueError): s.suggest("my_suggestion", regex="py[thon|py]") diff --git a/tests/_sync/test_update_by_query.py b/tests/_sync/test_update_by_query.py index 6aa67c2d..68d89c50 
100644 --- a/tests/_sync/test_update_by_query.py +++ b/tests/_sync/test_update_by_query.py @@ -16,20 +16,22 @@ # under the License. from copy import deepcopy +from typing import Any import pytest from elasticsearch_dsl import Q, UpdateByQuery from elasticsearch_dsl.response import UpdateByQueryResponse +from elasticsearch_dsl.search_base import SearchBase -def test_ubq_starts_with_no_query(): +def test_ubq_starts_with_no_query() -> None: ubq = UpdateByQuery() assert ubq.query._proxied is None -def test_ubq_to_dict(): +def test_ubq_to_dict() -> None: ubq = UpdateByQuery() assert {} == ubq.to_dict() @@ -45,7 +47,7 @@ def test_ubq_to_dict(): assert {"extra_q": {"term": {"category": "conference"}}} == ubq.to_dict() -def test_complex_example(): +def test_complex_example() -> None: ubq = UpdateByQuery() ubq = ( ubq.query("match", title="python") @@ -83,7 +85,7 @@ def test_complex_example(): } == ubq.to_dict() -def test_exclude(): +def test_exclude() -> None: ubq = UpdateByQuery() ubq = ubq.exclude("match", title="python") @@ -96,7 +98,7 @@ def test_exclude(): } == ubq.to_dict() -def test_reverse(): +def test_reverse() -> None: d = { "query": { "filtered": { @@ -132,14 +134,14 @@ def test_reverse(): assert d == ubq.to_dict() -def test_from_dict_doesnt_need_query(): +def test_from_dict_doesnt_need_query() -> None: ubq = UpdateByQuery.from_dict({"script": {"source": "test"}}) assert {"script": {"source": "test"}} == ubq.to_dict() @pytest.mark.sync -def test_params_being_passed_to_search(mock_client): +def test_params_being_passed_to_search(mock_client: Any) -> None: ubq = UpdateByQuery(using="mock", index="i") ubq = ubq.params(routing="42") ubq.execute() @@ -147,7 +149,7 @@ def test_params_being_passed_to_search(mock_client): mock_client.update_by_query.assert_called_once_with(index=["i"], routing="42") -def test_overwrite_script(): +def test_overwrite_script() -> None: ubq = UpdateByQuery() ubq = ubq.script( source="ctx._source.likes += params.f", lang="painless", params={"f": 3} @@ -163,12 +165,12 @@ def test_overwrite_script(): assert {"script": {"source": "ctx._source.likes++"}} == ubq.to_dict() -def test_update_by_query_response_success(): - ubqr = UpdateByQueryResponse({}, {"timed_out": False, "failures": []}) +def test_update_by_query_response_success() -> None: + ubqr = UpdateByQueryResponse(SearchBase(), {"timed_out": False, "failures": []}) assert ubqr.success() - ubqr = UpdateByQueryResponse({}, {"timed_out": True, "failures": []}) + ubqr = UpdateByQueryResponse(SearchBase(), {"timed_out": True, "failures": []}) assert not ubqr.success() - ubqr = UpdateByQueryResponse({}, {"timed_out": False, "failures": [{}]}) + ubqr = UpdateByQueryResponse(SearchBase(), {"timed_out": False, "failures": [{}]}) assert not ubqr.success() diff --git a/tests/async_sleep.py b/tests/async_sleep.py index ce700003..ce5ced1c 100644 --- a/tests/async_sleep.py +++ b/tests/async_sleep.py @@ -16,8 +16,9 @@ # under the License. 
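# The test_update_by_query.py hunks above likewise only tighten types.
# A short sketch of the UpdateByQuery behaviour they assert, runnable
# without a cluster since only to_dict() is inspected (the index and
# field names are illustrative):
from elasticsearch_dsl import UpdateByQuery

ubq = UpdateByQuery(index="i").query("match", title="python")
ubq = ubq.script(
    source="ctx._source.likes += params.f", lang="painless", params={"f": 3}
)
# calling script() again replaces the previous script rather than merging
ubq = ubq.script(source="ctx._source.likes++")
assert ubq.to_dict()["script"] == {"source": "ctx._source.likes++"}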
import asyncio +from typing import Union -async def sleep(secs): +async def sleep(secs: Union[int, float]) -> None: """Tests can use this function to sleep.""" await asyncio.sleep(secs) diff --git a/tests/conftest.py b/tests/conftest.py index f1751203..1a07670e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,6 +21,7 @@ import re import time from datetime import datetime +from typing import Any, AsyncGenerator, Dict, Generator, Tuple, cast from unittest import SkipTest, TestCase from unittest.mock import AsyncMock, Mock @@ -31,6 +32,7 @@ from elasticsearch.helpers import bulk from pytest import fixture, skip +from elasticsearch_dsl import Search from elasticsearch_dsl.async_connections import add_connection as add_async_connection from elasticsearch_dsl.async_connections import connections as async_connections from elasticsearch_dsl.connections import add_connection, connections @@ -51,16 +53,12 @@ ELASTICSEARCH_URL = "http://localhost:9200" -def get_test_client(wait=True, **kwargs): +def get_test_client(wait: bool = True, **kwargs: Any) -> Elasticsearch: # construct kwargs from the environment - kw = {"request_timeout": 30} + kw: Dict[str, Any] = {"request_timeout": 30} if "PYTHON_CONNECTION_CLASS" in os.environ: - from elasticsearch import connection - - kw["connection_class"] = getattr( - connection, os.environ["PYTHON_CONNECTION_CLASS"] - ) + kw["node_class"] = os.environ["PYTHON_CONNECTION_CLASS"] kw.update(kwargs) client = Elasticsearch(ELASTICSEARCH_URL, **kw) @@ -78,16 +76,12 @@ def get_test_client(wait=True, **kwargs): raise SkipTest("Elasticsearch failed to start.") -async def get_async_test_client(wait=True, **kwargs): +async def get_async_test_client(wait: bool = True, **kwargs: Any) -> AsyncElasticsearch: # construct kwargs from the environment - kw = {"request_timeout": 30} + kw: Dict[str, Any] = {"request_timeout": 30} if "PYTHON_CONNECTION_CLASS" in os.environ: - from elasticsearch import connection - - kw["connection_class"] = getattr( - connection, os.environ["PYTHON_CONNECTION_CLASS"] - ) + kw["node_class"] = os.environ["PYTHON_CONNECTION_CLASS"] kw.update(kwargs) client = AsyncElasticsearch(ELASTICSEARCH_URL, **kw) @@ -107,35 +101,35 @@ async def get_async_test_client(wait=True, **kwargs): class ElasticsearchTestCase(TestCase): + client: Elasticsearch + @staticmethod - def _get_client(): + def _get_client() -> Elasticsearch: return get_test_client() @classmethod - def setup_class(cls): + def setup_class(cls) -> None: cls.client = cls._get_client() - def teardown_method(self, _): + def teardown_method(self, _: Any) -> None: # Hidden indices expanded in wildcards in ES 7.7 expand_wildcards = ["open", "closed"] if self.es_version() >= (7, 7): expand_wildcards.append("hidden") self.client.indices.delete_data_stream( - name="*", ignore=404, expand_wildcards=expand_wildcards - ) - self.client.indices.delete( - index="*", ignore=404, expand_wildcards=expand_wildcards + name="*", expand_wildcards=expand_wildcards ) - self.client.indices.delete_template(name="*", ignore=404) + self.client.indices.delete(index="*", expand_wildcards=expand_wildcards) + self.client.indices.delete_template(name="*") - def es_version(self): + def es_version(self) -> Tuple[int, ...]: if not hasattr(self, "_es_version"): - self._es_version = _get_version(client.info()["version"]["number"]) + self._es_version = _get_version(self.client.info()["version"]["number"]) return self._es_version -def _get_version(version_string): +def _get_version(version_string: str) -> Tuple[int, ...]: if "." 
not in version_string: return () version = version_string.strip().split(".") @@ -143,7 +137,7 @@ def _get_version(version_string): @fixture(scope="session") -def client(): +def client() -> Elasticsearch: try: connection = get_test_client(wait="WAIT_FOR_ES" in os.environ) add_connection("default", connection) @@ -153,7 +147,7 @@ def client(): @pytest_asyncio.fixture -async def async_client(): +async def async_client() -> AsyncGenerator[AsyncElasticsearch, None]: try: connection = await get_async_test_client(wait="WAIT_FOR_ES" in os.environ) add_async_connection("default", connection) @@ -164,17 +158,16 @@ async def async_client(): @fixture(scope="session") -def es_version(client): +def es_version(client: Elasticsearch) -> Generator[Tuple[int, ...], None, None]: info = client.info() - print(info) yield tuple( int(x) - for x in re.match(r"^([0-9.]+)", info["version"]["number"]).group(1).split(".") + for x in re.match(r"^([0-9.]+)", info["version"]["number"]).group(1).split(".") # type: ignore ) @fixture -def write_client(client): +def write_client(client: Elasticsearch) -> Generator[Elasticsearch, None, None]: yield client for index_name in client.indices.get(index="test-*", expand_wildcards="all"): client.indices.delete(index=index_name) @@ -182,24 +175,30 @@ def write_client(client): @pytest_asyncio.fixture -async def async_write_client(write_client, async_client): +async def async_write_client( + write_client: Elasticsearch, async_client: AsyncElasticsearch +) -> AsyncGenerator[AsyncElasticsearch, None]: yield async_client @fixture -def mock_client(dummy_response): +def mock_client( + dummy_response: ObjectApiResponse[Any], +) -> Generator[Elasticsearch, None, None]: client = Mock() client.search.return_value = dummy_response client.update_by_query.return_value = dummy_response add_connection("mock", client) yield client - connections._conn = {} + connections._conns = {} connections._kwargs = {} @fixture -def async_mock_client(dummy_response): +def async_mock_client( + dummy_response: ObjectApiResponse[Any], +) -> Generator[Elasticsearch, None, None]: client = Mock() client.search = AsyncMock(return_value=dummy_response) client.indices = AsyncMock() @@ -208,12 +207,12 @@ def async_mock_client(dummy_response): add_async_connection("mock", client) yield client - async_connections._conn = {} + async_connections._conns = {} async_connections._kwargs = {} @fixture(scope="session") -def data_client(client): +def data_client(client: Elasticsearch) -> Generator[Elasticsearch, None, None]: # create mappings create_git_index(client, "git") create_flat_git_index(client, "flat-git") @@ -226,12 +225,14 @@ def data_client(client): @pytest_asyncio.fixture -async def async_data_client(data_client, async_client): +async def async_data_client( + data_client: Elasticsearch, async_client: AsyncElasticsearch +) -> AsyncGenerator[AsyncElasticsearch, None]: yield async_client @fixture -def dummy_response(): +def dummy_response() -> ObjectApiResponse[Any]: return ObjectApiResponse( meta=None, body={ @@ -287,9 +288,7 @@ def dummy_response(): @fixture -def aggs_search(): - from elasticsearch_dsl import Search - +def aggs_search() -> Search: s = Search(index="flat-git") s.aggs.bucket("popular_files", "terms", field="files", size=2).metric( "line_stats", "stats", field="stats.lines" @@ -302,7 +301,7 @@ def aggs_search(): @fixture -def aggs_data(): +def aggs_data() -> Dict[str, Any]: return { "took": 4, "timed_out": False, @@ -437,7 +436,7 @@ def aggs_data(): } -def make_pr(pr_module): +def make_pr(pr_module: Any) -> 
Any: return pr_module.PullRequest( _id=42, comments=[ @@ -458,23 +457,25 @@ def make_pr(pr_module): @fixture -def pull_request(write_client): +def pull_request(write_client: Elasticsearch) -> sync_document.PullRequest: sync_document.PullRequest.init() - pr = make_pr(sync_document) + pr = cast(sync_document.PullRequest, make_pr(sync_document)) pr.save(refresh=True) return pr @pytest_asyncio.fixture -async def async_pull_request(async_write_client): +async def async_pull_request( + async_write_client: AsyncElasticsearch, +) -> async_document.PullRequest: await async_document.PullRequest.init() - pr = make_pr(async_document) + pr = cast(async_document.PullRequest, make_pr(async_document)) await pr.save(refresh=True) return pr @fixture -def setup_ubq_tests(client) -> str: +def setup_ubq_tests(client: Elasticsearch) -> str: index = "test-git" create_git_index(client, index) bulk(client, TEST_GIT_DATA, raise_on_error=True, refresh=True) diff --git a/tests/sleep.py b/tests/sleep.py index b45d71b6..83009566 100644 --- a/tests/sleep.py +++ b/tests/sleep.py @@ -16,8 +16,9 @@ # under the License. import time +from typing import Union -def sleep(secs): +def sleep(secs: Union[int, float]) -> None: """Tests can use this function to sleep.""" time.sleep(secs) diff --git a/tests/test_integration/_async/test_analysis.py b/tests/test_integration/_async/test_analysis.py index ef157c51..be937dd6 100644 --- a/tests/test_integration/_async/test_analysis.py +++ b/tests/test_integration/_async/test_analysis.py @@ -16,12 +16,15 @@ # under the License. import pytest +from elasticsearch import AsyncElasticsearch from elasticsearch_dsl import analyzer, token_filter, tokenizer @pytest.mark.asyncio -async def test_simulate_with_just__builtin_tokenizer(async_client): +async def test_simulate_with_just__builtin_tokenizer( + async_client: AsyncElasticsearch, +) -> None: a = analyzer("my-analyzer", tokenizer="keyword") tokens = (await a.async_simulate("Hello World!", using=async_client)).tokens @@ -30,7 +33,7 @@ async def test_simulate_with_just__builtin_tokenizer(async_client): @pytest.mark.asyncio -async def test_simulate_complex(async_client): +async def test_simulate_complex(async_client: AsyncElasticsearch) -> None: a = analyzer( "my-analyzer", tokenizer=tokenizer("split_words", "simple_pattern_split", pattern=":"), @@ -44,7 +47,7 @@ async def test_simulate_complex(async_client): @pytest.mark.asyncio -async def test_simulate_builtin(async_client): +async def test_simulate_builtin(async_client: AsyncElasticsearch) -> None: a = analyzer("my-analyzer", "english") tokens = (await a.async_simulate("fixes running")).tokens diff --git a/tests/test_integration/_async/test_document.py b/tests/test_integration/_async/test_document.py index 10677fc5..36173eb6 100644 --- a/tests/test_integration/_async/test_document.py +++ b/tests/test_integration/_async/test_document.py @@ -15,16 +15,24 @@ # specific language governing permissions and limitations # under the License. +# this file creates several documents using bad or no types because +# these are still supported and should be kept functional in spite +# of not having appropriate type hints. For that reason the comment +# below disables many mypy checks that fails as a result of this. 
+# mypy: disable-error-code="assignment, index, arg-type, call-arg, operator, comparison-overlap, attr-defined" + from datetime import datetime from ipaddress import ip_address +from typing import Any import pytest -from elasticsearch import ConflictError, NotFoundError +from elasticsearch import AsyncElasticsearch, ConflictError, NotFoundError from pytest import raises from pytz import timezone from elasticsearch_dsl import ( AsyncDocument, + AsyncSearch, Binary, Boolean, Date, @@ -67,7 +75,7 @@ class Repository(AsyncDocument): tags = Keyword() @classmethod - def search(cls): + def search(cls) -> AsyncSearch["Repository"]: # type: ignore[override] return super().search().filter("term", commit_repo="repo") class Index: @@ -128,7 +136,7 @@ class Index: @pytest.mark.asyncio -async def test_serialization(async_write_client): +async def test_serialization(async_write_client: AsyncElasticsearch) -> None: await SerializationDoc.init() await async_write_client.index( index="test-serialization", @@ -142,6 +150,7 @@ async def test_serialization(async_write_client): }, ) sd = await SerializationDoc.get(id=42) + assert sd is not None assert sd.i == [1, 2, 3, None] assert sd.b == [True, False, True, False, None] @@ -159,7 +168,7 @@ async def test_serialization(async_write_client): @pytest.mark.asyncio -async def test_nested_inner_hits_are_wrapped_properly(async_pull_request): +async def test_nested_inner_hits_are_wrapped_properly(async_pull_request: Any) -> None: history_query = Q( "nested", path="comments.history", @@ -188,7 +197,9 @@ async def test_nested_inner_hits_are_wrapped_properly(async_pull_request): @pytest.mark.asyncio -async def test_nested_inner_hits_are_deserialized_properly(async_pull_request): +async def test_nested_inner_hits_are_deserialized_properly( + async_pull_request: Any, +) -> None: s = PullRequest.search().query( "nested", inner_hits={}, @@ -204,7 +215,7 @@ async def test_nested_inner_hits_are_deserialized_properly(async_pull_request): @pytest.mark.asyncio -async def test_nested_top_hits_are_wrapped_properly(async_pull_request): +async def test_nested_top_hits_are_wrapped_properly(async_pull_request: Any) -> None: s = PullRequest.search() s.aggs.bucket("comments", "nested", path="comments").metric( "hits", "top_hits", size=1 @@ -217,7 +228,7 @@ async def test_nested_top_hits_are_wrapped_properly(async_pull_request): @pytest.mark.asyncio -async def test_update_object_field(async_write_client): +async def test_update_object_field(async_write_client: AsyncElasticsearch) -> None: await Wiki.init() w = Wiki( owner=User(name="Honza Kral"), @@ -238,7 +249,7 @@ async def test_update_object_field(async_write_client): @pytest.mark.asyncio -async def test_update_script(async_write_client): +async def test_update_script(async_write_client: AsyncElasticsearch) -> None: await Wiki.init() w = Wiki(owner=User(name="Honza Kral"), _id="elasticsearch-py", views=42) await w.save() @@ -249,7 +260,7 @@ async def test_update_script(async_write_client): @pytest.mark.asyncio -async def test_update_script_with_dict(async_write_client): +async def test_update_script_with_dict(async_write_client: AsyncElasticsearch) -> None: await Wiki.init() w = Wiki(owner=User(name="Honza Kral"), _id="elasticsearch-py", views=42) await w.save() @@ -267,13 +278,16 @@ async def test_update_script_with_dict(async_write_client): @pytest.mark.asyncio -async def test_update_retry_on_conflict(async_write_client): +async def test_update_retry_on_conflict(async_write_client: AsyncElasticsearch) -> None: await Wiki.init() w = 
Wiki(owner=User(name="Honza Kral"), _id="elasticsearch-py", views=42) await w.save() w1 = await Wiki.get(id="elasticsearch-py") w2 = await Wiki.get(id="elasticsearch-py") + assert w1 is not None + assert w2 is not None + await w1.update( script="ctx._source.views += params.inc", inc=5, retry_on_conflict=1 ) @@ -287,13 +301,18 @@ async def test_update_retry_on_conflict(async_write_client): @pytest.mark.asyncio @pytest.mark.parametrize("retry_on_conflict", [None, 0]) -async def test_update_conflicting_version(async_write_client, retry_on_conflict): +async def test_update_conflicting_version( + async_write_client: AsyncElasticsearch, retry_on_conflict: bool +) -> None: await Wiki.init() w = Wiki(owner=User(name="Honza Kral"), _id="elasticsearch-py", views=42) await w.save() w1 = await Wiki.get(id="elasticsearch-py") w2 = await Wiki.get(id="elasticsearch-py") + assert w1 is not None + assert w2 is not None + await w1.update(script="ctx._source.views += params.inc", inc=5) with raises(ConflictError): @@ -305,7 +324,9 @@ async def test_update_conflicting_version(async_write_client, retry_on_conflict) @pytest.mark.asyncio -async def test_save_and_update_return_doc_meta(async_write_client): +async def test_save_and_update_return_doc_meta( + async_write_client: AsyncElasticsearch, +) -> None: await Wiki.init() w = Wiki(owner=User(name="Honza Kral"), _id="elasticsearch-py", views=42) resp = await w.save(return_doc_meta=True) @@ -338,26 +359,32 @@ async def test_save_and_update_return_doc_meta(async_write_client): @pytest.mark.asyncio -async def test_init(async_write_client): +async def test_init(async_write_client: AsyncElasticsearch) -> None: await Repository.init(index="test-git") assert await async_write_client.indices.exists(index="test-git") @pytest.mark.asyncio -async def test_get_raises_404_on_index_missing(async_data_client): +async def test_get_raises_404_on_index_missing( + async_data_client: AsyncElasticsearch, +) -> None: with raises(NotFoundError): await Repository.get("elasticsearch-dsl-php", index="not-there") @pytest.mark.asyncio -async def test_get_raises_404_on_non_existent_id(async_data_client): +async def test_get_raises_404_on_non_existent_id( + async_data_client: AsyncElasticsearch, +) -> None: with raises(NotFoundError): await Repository.get("elasticsearch-dsl-php") @pytest.mark.asyncio -async def test_get_returns_none_if_404_ignored(async_data_client): +async def test_get_returns_none_if_404_ignored( + async_data_client: AsyncElasticsearch, +) -> None: assert None is await Repository.get( "elasticsearch-dsl-php", using=async_data_client.options(ignore_status=404) ) @@ -365,15 +392,15 @@ async def test_get_returns_none_if_404_ignored(async_data_client): @pytest.mark.asyncio async def test_get_returns_none_if_404_ignored_and_index_doesnt_exist( - async_data_client, -): + async_data_client: AsyncElasticsearch, +) -> None: assert None is await Repository.get( "42", index="not-there", using=async_data_client.options(ignore_status=404) ) @pytest.mark.asyncio -async def test_get(async_data_client): +async def test_get(async_data_client: AsyncElasticsearch) -> None: elasticsearch_repo = await Repository.get("elasticsearch-dsl-py") assert isinstance(elasticsearch_repo, Repository) @@ -382,20 +409,21 @@ async def test_get(async_data_client): @pytest.mark.asyncio -async def test_exists_return_true(async_data_client): +async def test_exists_return_true(async_data_client: AsyncElasticsearch) -> None: assert await Repository.exists("elasticsearch-dsl-py") @pytest.mark.asyncio -async def 
test_exists_false(async_data_client): +async def test_exists_false(async_data_client: AsyncElasticsearch) -> None: assert not await Repository.exists("elasticsearch-dsl-php") @pytest.mark.asyncio -async def test_get_with_tz_date(async_data_client): +async def test_get_with_tz_date(async_data_client: AsyncElasticsearch) -> None: first_commit = await Commit.get( id="3ca6e1e73a071a705b4babd2f581c91a2a3e5037", routing="elasticsearch-dsl-py" ) + assert first_commit is not None tzinfo = timezone("Europe/Prague") assert ( @@ -405,11 +433,13 @@ async def test_get_with_tz_date(async_data_client): @pytest.mark.asyncio -async def test_save_with_tz_date(async_data_client): +async def test_save_with_tz_date(async_data_client: AsyncElasticsearch) -> None: tzinfo = timezone("Europe/Prague") first_commit = await Commit.get( id="3ca6e1e73a071a705b4babd2f581c91a2a3e5037", routing="elasticsearch-dsl-py" ) + assert first_commit is not None + first_commit.committed_date = tzinfo.localize( datetime(2014, 5, 2, 13, 47, 19, 123456) ) @@ -418,6 +448,8 @@ async def test_save_with_tz_date(async_data_client): first_commit = await Commit.get( id="3ca6e1e73a071a705b4babd2f581c91a2a3e5037", routing="elasticsearch-dsl-py" ) + assert first_commit is not None + assert ( tzinfo.localize(datetime(2014, 5, 2, 13, 47, 19, 123456)) == first_commit.committed_date @@ -433,48 +465,62 @@ async def test_save_with_tz_date(async_data_client): @pytest.mark.asyncio -async def test_mget(async_data_client): +async def test_mget(async_data_client: AsyncElasticsearch) -> None: commits = await Commit.mget(COMMIT_DOCS_WITH_MISSING) assert commits[0] is None + assert commits[1] is not None assert commits[1].meta.id == "3ca6e1e73a071a705b4babd2f581c91a2a3e5037" assert commits[2] is None + assert commits[3] is not None assert commits[3].meta.id == "eb3e543323f189fd7b698e66295427204fff5755" @pytest.mark.asyncio -async def test_mget_raises_exception_when_missing_param_is_invalid(async_data_client): +async def test_mget_raises_exception_when_missing_param_is_invalid( + async_data_client: AsyncElasticsearch, +) -> None: with raises(ValueError): await Commit.mget(COMMIT_DOCS_WITH_MISSING, missing="raj") @pytest.mark.asyncio -async def test_mget_raises_404_when_missing_param_is_raise(async_data_client): +async def test_mget_raises_404_when_missing_param_is_raise( + async_data_client: AsyncElasticsearch, +) -> None: with raises(NotFoundError): await Commit.mget(COMMIT_DOCS_WITH_MISSING, missing="raise") @pytest.mark.asyncio -async def test_mget_ignores_missing_docs_when_missing_param_is_skip(async_data_client): +async def test_mget_ignores_missing_docs_when_missing_param_is_skip( + async_data_client: AsyncElasticsearch, +) -> None: commits = await Commit.mget(COMMIT_DOCS_WITH_MISSING, missing="skip") + assert commits[0] is not None assert commits[0].meta.id == "3ca6e1e73a071a705b4babd2f581c91a2a3e5037" + assert commits[1] is not None assert commits[1].meta.id == "eb3e543323f189fd7b698e66295427204fff5755" @pytest.mark.asyncio -async def test_update_works_from_search_response(async_data_client): +async def test_update_works_from_search_response( + async_data_client: AsyncElasticsearch, +) -> None: elasticsearch_repo = (await Repository.search().execute())[0] await elasticsearch_repo.update(owner={"other_name": "elastic"}) assert "elastic" == elasticsearch_repo.owner.other_name new_version = await Repository.get("elasticsearch-dsl-py") + assert new_version is not None assert "elastic" == new_version.owner.other_name assert "elasticsearch" == 
new_version.owner.name @pytest.mark.asyncio -async def test_update(async_data_client): +async def test_update(async_data_client: AsyncElasticsearch) -> None: elasticsearch_repo = await Repository.get("elasticsearch-dsl-py") + assert elasticsearch_repo is not None v = elasticsearch_repo.meta.version old_seq_no = elasticsearch_repo.meta.seq_no @@ -489,6 +535,7 @@ async def test_update(async_data_client): assert elasticsearch_repo.meta.version == v + 1 new_version = await Repository.get("elasticsearch-dsl-py") + assert new_version is not None assert "testing-update" == new_version.new_field assert "elastic" == new_version.owner.new_name assert "elasticsearch" == new_version.owner.name @@ -498,8 +545,9 @@ async def test_update(async_data_client): @pytest.mark.asyncio -async def test_save_updates_existing_doc(async_data_client): +async def test_save_updates_existing_doc(async_data_client: AsyncElasticsearch) -> None: elasticsearch_repo = await Repository.get("elasticsearch-dsl-py") + assert elasticsearch_repo is not None elasticsearch_repo.new_field = "testing-save" old_seq_no = elasticsearch_repo.meta.seq_no @@ -512,7 +560,7 @@ async def test_save_updates_existing_doc(async_data_client): @pytest.mark.asyncio -async def test_update_empty_field(async_client): +async def test_update_empty_field(async_client: AsyncElasticsearch) -> None: await Tags._index.delete(ignore_unavailable=True) await Tags.init() d = Tags(id="123", tags=["a", "b"]) @@ -525,8 +573,11 @@ async def test_update_empty_field(async_client): @pytest.mark.asyncio -async def test_save_automatically_uses_seq_no_and_primary_term(async_data_client): +async def test_save_automatically_uses_seq_no_and_primary_term( + async_data_client: AsyncElasticsearch, +) -> None: elasticsearch_repo = await Repository.get("elasticsearch-dsl-py") + assert elasticsearch_repo is not None elasticsearch_repo.meta.seq_no += 1 with raises(ConflictError): @@ -534,22 +585,27 @@ async def test_save_automatically_uses_seq_no_and_primary_term(async_data_client @pytest.mark.asyncio -async def test_delete_automatically_uses_seq_no_and_primary_term(async_data_client): +async def test_delete_automatically_uses_seq_no_and_primary_term( + async_data_client: AsyncElasticsearch, +) -> None: elasticsearch_repo = await Repository.get("elasticsearch-dsl-py") + assert elasticsearch_repo is not None elasticsearch_repo.meta.seq_no += 1 with raises(ConflictError): await elasticsearch_repo.delete() -def assert_doc_equals(expected, actual): +def assert_doc_equals(expected: Any, actual: Any) -> None: for f in expected: assert f in actual assert actual[f] == expected[f] @pytest.mark.asyncio -async def test_can_save_to_different_index(async_write_client): +async def test_can_save_to_different_index( + async_write_client: AsyncElasticsearch, +) -> None: test_repo = Repository(description="testing", meta={"id": 42}) assert await test_repo.save(index="test-document") @@ -565,7 +621,9 @@ async def test_can_save_to_different_index(async_write_client): @pytest.mark.asyncio -async def test_save_without_skip_empty_will_include_empty_fields(async_write_client): +async def test_save_without_skip_empty_will_include_empty_fields( + async_write_client: AsyncElasticsearch, +) -> None: test_repo = Repository(field_1=[], field_2=None, field_3={}, meta={"id": 42}) assert await test_repo.save(index="test-document", skip_empty=False) @@ -581,7 +639,7 @@ async def test_save_without_skip_empty_will_include_empty_fields(async_write_cli @pytest.mark.asyncio -async def test_delete(async_write_client): 
+async def test_delete(async_write_client: AsyncElasticsearch) -> None: await async_write_client.create( index="test-document", id="elasticsearch-dsl-py", @@ -603,12 +661,14 @@ async def test_delete(async_write_client): @pytest.mark.asyncio -async def test_search(async_data_client): +async def test_search(async_data_client: AsyncElasticsearch) -> None: assert await Repository.search().count() == 1 @pytest.mark.asyncio -async def test_search_returns_proper_doc_classes(async_data_client): +async def test_search_returns_proper_doc_classes( + async_data_client: AsyncElasticsearch, +) -> None: result = await Repository.search().execute() elasticsearch_repo = result.hits[0] @@ -618,7 +678,7 @@ async def test_search_returns_proper_doc_classes(async_data_client): @pytest.mark.asyncio -async def test_refresh_mapping(async_data_client): +async def test_refresh_mapping(async_data_client: AsyncElasticsearch) -> None: class Commit(AsyncDocument): class Index: name = "git" @@ -633,7 +693,7 @@ class Index: @pytest.mark.asyncio -async def test_highlight_in_meta(async_data_client): +async def test_highlight_in_meta(async_data_client: AsyncElasticsearch) -> None: commit = ( await Commit.search() .query("match", description="inverting") diff --git a/tests/test_integration/_async/test_faceted_search.py b/tests/test_integration/_async/test_faceted_search.py index 7b23ab91..7ba6e5de 100644 --- a/tests/test_integration/_async/test_faceted_search.py +++ b/tests/test_integration/_async/test_faceted_search.py @@ -16,10 +16,12 @@ # under the License. from datetime import datetime +from typing import Tuple, Type import pytest +from elasticsearch import AsyncElasticsearch -from elasticsearch_dsl import A, AsyncDocument, Boolean, Date, Keyword +from elasticsearch_dsl import A, AsyncDocument, AsyncSearch, Boolean, Date, Keyword from elasticsearch_dsl.faceted_search import ( AsyncFacetedSearch, DateHistogramFacet, @@ -57,7 +59,7 @@ class MetricSearch(AsyncFacetedSearch): @pytest.fixture(scope="session") -def commit_search_cls(es_version): +def commit_search_cls(es_version: Tuple[int, ...]) -> Type[AsyncFacetedSearch]: if es_version >= (7, 2): interval_kwargs = {"fixed_interval": "1d"} else: @@ -85,7 +87,7 @@ class CommitSearch(AsyncFacetedSearch): @pytest.fixture(scope="session") -def repo_search_cls(es_version): +def repo_search_cls(es_version: Tuple[int, ...]) -> Type[AsyncFacetedSearch]: interval_type = "calendar_interval" if es_version >= (7, 2) else "interval" class RepoSearch(AsyncFacetedSearch): @@ -98,7 +100,7 @@ class RepoSearch(AsyncFacetedSearch): ), } - def search(self): + def search(self) -> AsyncSearch: s = super().search() return s.filter("term", commit_repo="repo") @@ -106,7 +108,7 @@ def search(self): @pytest.fixture(scope="session") -def pr_search_cls(es_version): +def pr_search_cls(es_version: Tuple[int, ...]) -> Type[AsyncFacetedSearch]: interval_type = "calendar_interval" if es_version >= (7, 2) else "interval" class PRSearch(AsyncFacetedSearch): @@ -125,7 +127,7 @@ class PRSearch(AsyncFacetedSearch): @pytest.mark.asyncio -async def test_facet_with_custom_metric(async_data_client): +async def test_facet_with_custom_metric(async_data_client: AsyncElasticsearch) -> None: ms = MetricSearch() r = await ms.execute() @@ -135,20 +137,24 @@ async def test_facet_with_custom_metric(async_data_client): @pytest.mark.asyncio -async def test_nested_facet(async_pull_request, pr_search_cls): +async def test_nested_facet( + async_pull_request: PullRequest, pr_search_cls: Type[AsyncFacetedSearch] +) -> None: prs = 
pr_search_cls() r = await prs.execute() - assert r.hits.total.value == 1 + assert r.hits.total.value == 1 # type: ignore[attr-defined] assert [(datetime(2018, 1, 1, 0, 0), 1, False)] == r.facets.comments @pytest.mark.asyncio -async def test_nested_facet_with_filter(async_pull_request, pr_search_cls): +async def test_nested_facet_with_filter( + async_pull_request: PullRequest, pr_search_cls: Type[AsyncFacetedSearch] +) -> None: prs = pr_search_cls(filters={"comments": datetime(2018, 1, 1, 0, 0)}) r = await prs.execute() - assert r.hits.total.value == 1 + assert r.hits.total.value == 1 # type: ignore[attr-defined] assert [(datetime(2018, 1, 1, 0, 0), 1, True)] == r.facets.comments prs = pr_search_cls(filters={"comments": datetime(2018, 2, 1, 0, 0)}) @@ -157,20 +163,24 @@ async def test_nested_facet_with_filter(async_pull_request, pr_search_cls): @pytest.mark.asyncio -async def test_datehistogram_facet(async_data_client, repo_search_cls): +async def test_datehistogram_facet( + async_data_client: AsyncElasticsearch, repo_search_cls: Type[AsyncFacetedSearch] +) -> None: rs = repo_search_cls() r = await rs.execute() - assert r.hits.total.value == 1 + assert r.hits.total.value == 1 # type: ignore[attr-defined] assert [(datetime(2014, 3, 1, 0, 0), 1, False)] == r.facets.created @pytest.mark.asyncio -async def test_boolean_facet(async_data_client, repo_search_cls): +async def test_boolean_facet( + async_data_client: AsyncElasticsearch, repo_search_cls: Type[AsyncFacetedSearch] +) -> None: rs = repo_search_cls() r = await rs.execute() - assert r.hits.total.value == 1 + assert r.hits.total.value == 1 # type: ignore[attr-defined] assert [(True, 1, False)] == r.facets.public value, count, selected = r.facets.public[0] assert value is True @@ -178,12 +188,14 @@ async def test_boolean_facet(async_data_client, repo_search_cls): @pytest.mark.asyncio async def test_empty_search_finds_everything( - async_data_client, es_version, commit_search_cls -): + async_data_client: AsyncElasticsearch, + es_version: Tuple[int, ...], + commit_search_cls: Type[AsyncFacetedSearch], +) -> None: cs = commit_search_cls() r = await cs.execute() - assert r.hits.total.value == 52 + assert r.hits.total.value == 52 # type: ignore[attr-defined] assert [ ("elasticsearch_dsl", 40, False), ("test_elasticsearch_dsl", 35, False), @@ -226,13 +238,13 @@ async def test_empty_search_finds_everything( @pytest.mark.asyncio async def test_term_filters_are_shown_as_selected_and_data_is_filtered( - async_data_client, commit_search_cls -): + async_data_client: AsyncElasticsearch, commit_search_cls: Type[AsyncFacetedSearch] +) -> None: cs = commit_search_cls(filters={"files": "test_elasticsearch_dsl"}) r = await cs.execute() - assert 35 == r.hits.total.value + assert 35 == r.hits.total.value # type: ignore[attr-defined] assert [ ("elasticsearch_dsl", 40, False), ("test_elasticsearch_dsl", 35, True), # selected @@ -273,17 +285,19 @@ async def test_term_filters_are_shown_as_selected_and_data_is_filtered( @pytest.mark.asyncio async def test_range_filters_are_shown_as_selected_and_data_is_filtered( - async_data_client, commit_search_cls -): + async_data_client: AsyncElasticsearch, commit_search_cls: Type[AsyncFacetedSearch] +) -> None: cs = commit_search_cls(filters={"deletions": "better"}) r = await cs.execute() - assert 19 == r.hits.total.value + assert 19 == r.hits.total.value # type: ignore[attr-defined] @pytest.mark.asyncio -async def test_pagination(async_data_client, commit_search_cls): +async def test_pagination( + async_data_client: 
AsyncElasticsearch, commit_search_cls: Type[AsyncFacetedSearch] +) -> None: cs = commit_search_cls() cs = cs[0:20] diff --git a/tests/test_integration/_async/test_index.py b/tests/test_integration/_async/test_index.py index efbd711d..21e4fa7c 100644 --- a/tests/test_integration/_async/test_index.py +++ b/tests/test_integration/_async/test_index.py @@ -16,6 +16,7 @@ # under the License. import pytest +from elasticsearch import AsyncElasticsearch from elasticsearch_dsl import ( AsyncDocument, @@ -33,7 +34,7 @@ class Post(AsyncDocument): @pytest.mark.asyncio -async def test_index_template_works(async_write_client): +async def test_index_template_works(async_write_client: AsyncElasticsearch) -> None: it = AsyncIndexTemplate("test-template", "test-*") it.document(Post) it.settings(number_of_replicas=0, number_of_shards=1) @@ -55,7 +56,9 @@ async def test_index_template_works(async_write_client): @pytest.mark.asyncio -async def test_index_can_be_saved_even_with_settings(async_write_client): +async def test_index_can_be_saved_even_with_settings( + async_write_client: AsyncElasticsearch, +) -> None: i = AsyncIndex("test-blog", using=async_write_client) i.settings(number_of_shards=3, number_of_replicas=0) await i.save() @@ -71,13 +74,15 @@ async def test_index_can_be_saved_even_with_settings(async_write_client): @pytest.mark.asyncio -async def test_index_exists(async_data_client): +async def test_index_exists(async_data_client: AsyncElasticsearch) -> None: assert await AsyncIndex("git").exists() assert not await AsyncIndex("not-there").exists() @pytest.mark.asyncio -async def test_index_can_be_created_with_settings_and_mappings(async_write_client): +async def test_index_can_be_created_with_settings_and_mappings( + async_write_client: AsyncElasticsearch, +) -> None: i = AsyncIndex("test-blog", using=async_write_client) i.document(Post) i.settings(number_of_replicas=0, number_of_shards=1) @@ -103,7 +108,7 @@ async def test_index_can_be_created_with_settings_and_mappings(async_write_clien @pytest.mark.asyncio -async def test_delete(async_write_client): +async def test_delete(async_write_client: AsyncElasticsearch) -> None: await async_write_client.indices.create( index="test-index", body={"settings": {"number_of_replicas": 0, "number_of_shards": 1}}, @@ -115,7 +120,9 @@ async def test_delete(async_write_client): @pytest.mark.asyncio -async def test_multiple_indices_with_same_doc_type_work(async_write_client): +async def test_multiple_indices_with_same_doc_type_work( + async_write_client: AsyncElasticsearch, +) -> None: i1 = AsyncIndex("test-index-1", using=async_write_client) i2 = AsyncIndex("test-index-2", using=async_write_client) @@ -123,8 +130,8 @@ async def test_multiple_indices_with_same_doc_type_work(async_write_client): i.document(Post) await i.create() - for i in ("test-index-1", "test-index-2"): - settings = await async_write_client.indices.get_settings(index=i) - assert settings[i]["settings"]["index"]["analysis"] == { + for j in ("test-index-1", "test-index-2"): + settings = await async_write_client.indices.get_settings(index=j) + assert settings[j]["settings"]["index"]["analysis"] == { "analyzer": {"my_analyzer": {"type": "custom", "tokenizer": "keyword"}} } diff --git a/tests/test_integration/_async/test_mapping.py b/tests/test_integration/_async/test_mapping.py index 0c016d7b..c2a78a4c 100644 --- a/tests/test_integration/_async/test_mapping.py +++ b/tests/test_integration/_async/test_mapping.py @@ -16,13 +16,14 @@ # under the License. 
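# The async test_index.py hunks above pin their fixtures to
# AsyncElasticsearch. A compact sketch of the AsyncIndex flow those
# tests cover, assuming a cluster is reachable at the (illustrative)
# URL below:
import asyncio

from elasticsearch import AsyncElasticsearch

from elasticsearch_dsl import AsyncDocument, AsyncIndex, Text


class Post(AsyncDocument):
    title = Text()


async def main() -> None:
    client = AsyncElasticsearch("http://localhost:9200")
    i = AsyncIndex("test-blog", using=client)
    i.document(Post)  # register the document's mapping on the index
    i.settings(number_of_shards=1, number_of_replicas=0)
    await i.create()  # creates the index with settings and mappings
    assert await i.exists()
    await i.delete()
    await client.close()


asyncio.run(main())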
import pytest +from elasticsearch import AsyncElasticsearch from pytest import raises from elasticsearch_dsl import AsyncMapping, analysis, exceptions @pytest.mark.asyncio -async def test_mapping_saved_into_es(async_write_client): +async def test_mapping_saved_into_es(async_write_client: AsyncElasticsearch) -> None: m = AsyncMapping() m.field( "name", "text", analyzer=analysis.analyzer("my_analyzer", tokenizer="keyword") @@ -44,8 +45,8 @@ async def test_mapping_saved_into_es(async_write_client): @pytest.mark.asyncio async def test_mapping_saved_into_es_when_index_already_exists_closed( - async_write_client, -): + async_write_client: AsyncElasticsearch, +) -> None: m = AsyncMapping() m.field( "name", "text", analyzer=analysis.analyzer("my_analyzer", tokenizer="keyword") @@ -72,8 +73,8 @@ async def test_mapping_saved_into_es_when_index_already_exists_closed( @pytest.mark.asyncio async def test_mapping_saved_into_es_when_index_already_exists_with_analysis( - async_write_client, -): + async_write_client: AsyncElasticsearch, +) -> None: m = AsyncMapping() analyzer = analysis.analyzer("my_analyzer", tokenizer="keyword") m.field("name", "text", analyzer=analyzer) @@ -103,7 +104,9 @@ async def test_mapping_saved_into_es_when_index_already_exists_with_analysis( @pytest.mark.asyncio -async def test_mapping_gets_updated_from_es(async_write_client): +async def test_mapping_gets_updated_from_es( + async_write_client: AsyncElasticsearch, +) -> None: await async_write_client.indices.create( index="test-mapping", body={ @@ -136,7 +139,7 @@ async def test_mapping_gets_updated_from_es(async_write_client): m = await AsyncMapping.from_es("test-mapping", using=async_write_client) assert ["comments", "created_at", "title"] == list( - sorted(m.properties.properties._d_.keys()) + sorted(m.properties.properties._d_.keys()) # type: ignore[attr-defined] ) assert { "date_detection": False, diff --git a/tests/test_integration/_async/test_search.py b/tests/test_integration/_async/test_search.py index 2c329ee8..11bc8c72 100644 --- a/tests/test_integration/_async/test_search.py +++ b/tests/test_integration/_async/test_search.py @@ -17,7 +17,7 @@ import pytest -from elasticsearch import ApiError +from elasticsearch import ApiError, AsyncElasticsearch from pytest import raises from elasticsearch_dsl import ( @@ -40,7 +40,7 @@ class Repository(AsyncDocument): tags = Keyword() @classmethod - def search(cls): + def search(cls) -> AsyncSearch["Repository"]: # type: ignore[override] return super().search().filter("term", commit_repo="repo") class Index: @@ -53,7 +53,9 @@ class Index: @pytest.mark.asyncio -async def test_filters_aggregation_buckets_are_accessible(async_data_client): +async def test_filters_aggregation_buckets_are_accessible( + async_data_client: AsyncElasticsearch, +) -> None: has_tests_query = Q("term", files="test_elasticsearch_dsl") s = Commit.search()[0:0] s.aggs.bucket("top_authors", "terms", field="author.name.raw").bucket( @@ -76,7 +78,9 @@ async def test_filters_aggregation_buckets_are_accessible(async_data_client): @pytest.mark.asyncio -async def test_top_hits_are_wrapped_in_response(async_data_client): +async def test_top_hits_are_wrapped_in_response( + async_data_client: AsyncElasticsearch, +) -> None: s = Commit.search()[0:0] s.aggs.bucket("top_authors", "terms", field="author.name.raw").metric( "top_commits", "top_hits", size=5 @@ -93,7 +97,9 @@ async def test_top_hits_are_wrapped_in_response(async_data_client): @pytest.mark.asyncio -async def 
test_inner_hits_are_wrapped_in_response(async_data_client): +async def test_inner_hits_are_wrapped_in_response( + async_data_client: AsyncElasticsearch, +) -> None: s = AsyncSearch(index="git")[0:1].query( "has_parent", parent_type="repo", inner_hits={}, query=Q("match_all") ) @@ -107,7 +113,7 @@ async def test_inner_hits_are_wrapped_in_response(async_data_client): @pytest.mark.asyncio -async def test_scan_respects_doc_types(async_data_client): +async def test_scan_respects_doc_types(async_data_client: AsyncElasticsearch) -> None: repos = [repo async for repo in Repository.search().scan()] assert 1 == len(repos) @@ -116,7 +122,9 @@ async def test_scan_respects_doc_types(async_data_client): @pytest.mark.asyncio -async def test_scan_iterates_through_all_docs(async_data_client): +async def test_scan_iterates_through_all_docs( + async_data_client: AsyncElasticsearch, +) -> None: s = AsyncSearch(index="flat-git") commits = [commit async for commit in s.scan()] @@ -126,7 +134,7 @@ async def test_scan_iterates_through_all_docs(async_data_client): @pytest.mark.asyncio -async def test_search_after(async_data_client): +async def test_search_after(async_data_client: AsyncElasticsearch) -> None: page_size = 7 s = AsyncSearch(index="flat-git")[:page_size].sort("authored_date") commits = [] @@ -135,52 +143,52 @@ async def test_search_after(async_data_client): commits += r.hits if len(r.hits) < page_size: break - s = r.search_after() + s = s.search_after() assert 52 == len(commits) assert {d["_id"] for d in FLAT_DATA} == {c.meta.id for c in commits} @pytest.mark.asyncio -async def test_search_after_no_search(async_data_client): +async def test_search_after_no_search(async_data_client: AsyncElasticsearch) -> None: s = AsyncSearch(index="flat-git") with raises( ValueError, match="A search must be executed before using search_after" ): - await s.search_after() + s.search_after() await s.count() with raises( ValueError, match="A search must be executed before using search_after" ): - await s.search_after() + s.search_after() @pytest.mark.asyncio -async def test_search_after_no_sort(async_data_client): +async def test_search_after_no_sort(async_data_client: AsyncElasticsearch) -> None: s = AsyncSearch(index="flat-git") r = await s.execute() with raises( ValueError, match="Cannot use search_after when results are not sorted" ): - await r.search_after() + r.search_after() @pytest.mark.asyncio -async def test_search_after_no_results(async_data_client): +async def test_search_after_no_results(async_data_client: AsyncElasticsearch) -> None: s = AsyncSearch(index="flat-git")[:100].sort("authored_date") r = await s.execute() assert 52 == len(r.hits) - s = r.search_after() + s = s.search_after() r = await s.execute() assert 0 == len(r.hits) with raises( ValueError, match="Cannot use search_after when there are no search results" ): - await r.search_after() + r.search_after() @pytest.mark.asyncio -async def test_point_in_time(async_data_client): +async def test_point_in_time(async_data_client: AsyncElasticsearch) -> None: page_size = 7 commits = [] async with AsyncSearch(index="flat-git")[:page_size].point_in_time( @@ -192,7 +200,7 @@ async def test_point_in_time(async_data_client): commits += r.hits if len(r.hits) < page_size: break - s = r.search_after() + s = s.search_after() assert pit_id == s._extra["pit"]["id"] assert "30s" == s._extra["pit"]["keep_alive"] @@ -201,7 +209,7 @@ async def test_point_in_time(async_data_client): @pytest.mark.asyncio -async def test_iterate(async_data_client): +async def 
test_iterate(async_data_client: AsyncElasticsearch) -> None: s = AsyncSearch(index="flat-git") commits = [commit async for commit in s.iterate()] @@ -211,7 +219,7 @@ async def test_iterate(async_data_client): @pytest.mark.asyncio -async def test_response_is_cached(async_data_client): +async def test_response_is_cached(async_data_client: AsyncElasticsearch) -> None: s = Repository.search() repos = [repo async for repo in s] @@ -220,11 +228,11 @@ async def test_response_is_cached(async_data_client): @pytest.mark.asyncio -async def test_multi_search(async_data_client): +async def test_multi_search(async_data_client: AsyncElasticsearch) -> None: s1 = Repository.search() - s2 = AsyncSearch(index="flat-git") + s2 = AsyncSearch[Repository](index="flat-git") - ms = AsyncMultiSearch() + ms = AsyncMultiSearch[Repository]() ms = ms.add(s1).add(s2) r1, r2 = await ms.execute() @@ -233,17 +241,17 @@ async def test_multi_search(async_data_client): assert isinstance(r1[0], Repository) assert r1._search is s1 - assert 52 == r2.hits.total.value + assert 52 == r2.hits.total.value # type: ignore[attr-defined] assert r2._search is s2 @pytest.mark.asyncio -async def test_multi_missing(async_data_client): +async def test_multi_missing(async_data_client: AsyncElasticsearch) -> None: s1 = Repository.search() - s2 = AsyncSearch(index="flat-git") - s3 = AsyncSearch(index="does_not_exist") + s2 = AsyncSearch[Repository](index="flat-git") + s3 = AsyncSearch[Repository](index="does_not_exist") - ms = AsyncMultiSearch() + ms = AsyncMultiSearch[Repository]() ms = ms.add(s1).add(s2).add(s3) with raises(ApiError): @@ -255,14 +263,16 @@ async def test_multi_missing(async_data_client): assert isinstance(r1[0], Repository) assert r1._search is s1 - assert 52 == r2.hits.total.value + assert 52 == r2.hits.total.value # type: ignore[attr-defined] assert r2._search is s2 assert r3 is None @pytest.mark.asyncio -async def test_raw_subfield_can_be_used_in_aggs(async_data_client): +async def test_raw_subfield_can_be_used_in_aggs( + async_data_client: AsyncElasticsearch, +) -> None: s = AsyncSearch(index="git")[0:0] s.aggs.bucket("authors", "terms", field="author.name.raw", size=1) diff --git a/tests/test_integration/_async/test_update_by_query.py b/tests/test_integration/_async/test_update_by_query.py index 9d85a852..fccc4512 100644 --- a/tests/test_integration/_async/test_update_by_query.py +++ b/tests/test_integration/_async/test_update_by_query.py @@ -16,13 +16,16 @@ # under the License. 
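The hunks above consistently replace `r.search_after()` with `s.search_after()`: in this branch the method lives on the already-executed Search, returns a new Search, and is not a coroutine, which is also why the stray `await`s are dropped. A sketch of the resulting pagination loop, mirroring the test (index name and page size are illustrative):

```python
from elasticsearch_dsl import Search

page_size = 10
s = Search(index="flat-git")[:page_size].sort("authored_date")
while True:
    r = s.execute()           # a search must execute before search_after()
    for hit in r:
        ...                   # process the current page of hits
    if len(r.hits) < page_size:
        break                 # short page: no more results
    s = s.search_after()      # new Search starting after the last hit
```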
import pytest +from elasticsearch import AsyncElasticsearch from elasticsearch_dsl import AsyncUpdateByQuery from elasticsearch_dsl.search import Q @pytest.mark.asyncio -async def test_update_by_query_no_script(async_write_client, setup_ubq_tests): +async def test_update_by_query_no_script( + async_write_client: AsyncElasticsearch, setup_ubq_tests: str +) -> None: index = setup_ubq_tests ubq = ( @@ -42,7 +45,9 @@ async def test_update_by_query_no_script(async_write_client, setup_ubq_tests): @pytest.mark.asyncio -async def test_update_by_query_with_script(async_write_client, setup_ubq_tests: str): +async def test_update_by_query_with_script( + async_write_client: AsyncElasticsearch, setup_ubq_tests: str +) -> None: index = setup_ubq_tests ubq = ( @@ -60,7 +65,9 @@ async def test_update_by_query_with_script(async_write_client, setup_ubq_tests: @pytest.mark.asyncio -async def test_delete_by_query_with_script(async_write_client, setup_ubq_tests: str): +async def test_delete_by_query_with_script( + async_write_client: AsyncElasticsearch, setup_ubq_tests: str +) -> None: index = setup_ubq_tests ubq = ( diff --git a/tests/test_integration/_sync/test_analysis.py b/tests/test_integration/_sync/test_analysis.py index 567710ff..20f34f58 100644 --- a/tests/test_integration/_sync/test_analysis.py +++ b/tests/test_integration/_sync/test_analysis.py @@ -16,12 +16,15 @@ # under the License. import pytest +from elasticsearch import Elasticsearch from elasticsearch_dsl import analyzer, token_filter, tokenizer @pytest.mark.sync -def test_simulate_with_just__builtin_tokenizer(client): +def test_simulate_with_just__builtin_tokenizer( + client: Elasticsearch, +) -> None: a = analyzer("my-analyzer", tokenizer="keyword") tokens = (a.simulate("Hello World!", using=client)).tokens @@ -30,7 +33,7 @@ def test_simulate_with_just__builtin_tokenizer(client): @pytest.mark.sync -def test_simulate_complex(client): +def test_simulate_complex(client: Elasticsearch) -> None: a = analyzer( "my-analyzer", tokenizer=tokenizer("split_words", "simple_pattern_split", pattern=":"), @@ -44,7 +47,7 @@ def test_simulate_complex(client): @pytest.mark.sync -def test_simulate_builtin(client): +def test_simulate_builtin(client: Elasticsearch) -> None: a = analyzer("my-analyzer", "english") tokens = (a.simulate("fixes running")).tokens diff --git a/tests/test_integration/_sync/test_document.py b/tests/test_integration/_sync/test_document.py index 9718f080..a102a333 100644 --- a/tests/test_integration/_sync/test_document.py +++ b/tests/test_integration/_sync/test_document.py @@ -15,11 +15,18 @@ # specific language governing permissions and limitations # under the License. +# this file creates several documents using bad or no types because +# these are still supported and should be kept functional in spite +# of not having appropriate type hints. For that reason the comment +# below disables many mypy checks that fail as a result of this.
+# mypy: disable-error-code="assignment, index, arg-type, call-arg, operator, comparison-overlap, attr-defined" + from datetime import datetime from ipaddress import ip_address +from typing import Any import pytest -from elasticsearch import ConflictError, NotFoundError +from elasticsearch import ConflictError, Elasticsearch, NotFoundError from pytest import raises from pytz import timezone @@ -39,6 +46,7 @@ Object, Q, RankFeatures, + Search, Text, analyzer, ) @@ -67,7 +75,7 @@ class Repository(Document): tags = Keyword() @classmethod - def search(cls): + def search(cls) -> Search["Repository"]: # type: ignore[override] return super().search().filter("term", commit_repo="repo") class Index: @@ -128,7 +136,7 @@ class Index: @pytest.mark.sync -def test_serialization(write_client): +def test_serialization(write_client: Elasticsearch) -> None: SerializationDoc.init() write_client.index( index="test-serialization", @@ -142,6 +150,7 @@ def test_serialization(write_client): }, ) sd = SerializationDoc.get(id=42) + assert sd is not None assert sd.i == [1, 2, 3, None] assert sd.b == [True, False, True, False, None] @@ -159,7 +168,7 @@ def test_serialization(write_client): @pytest.mark.sync -def test_nested_inner_hits_are_wrapped_properly(pull_request): +def test_nested_inner_hits_are_wrapped_properly(pull_request: Any) -> None: history_query = Q( "nested", path="comments.history", @@ -188,7 +197,9 @@ def test_nested_inner_hits_are_wrapped_properly(pull_request): @pytest.mark.sync -def test_nested_inner_hits_are_deserialized_properly(pull_request): +def test_nested_inner_hits_are_deserialized_properly( + pull_request: Any, +) -> None: s = PullRequest.search().query( "nested", inner_hits={}, @@ -204,7 +215,7 @@ def test_nested_inner_hits_are_deserialized_properly(pull_request): @pytest.mark.sync -def test_nested_top_hits_are_wrapped_properly(pull_request): +def test_nested_top_hits_are_wrapped_properly(pull_request: Any) -> None: s = PullRequest.search() s.aggs.bucket("comments", "nested", path="comments").metric( "hits", "top_hits", size=1 @@ -217,7 +228,7 @@ def test_nested_top_hits_are_wrapped_properly(pull_request): @pytest.mark.sync -def test_update_object_field(write_client): +def test_update_object_field(write_client: Elasticsearch) -> None: Wiki.init() w = Wiki( owner=User(name="Honza Kral"), @@ -238,7 +249,7 @@ def test_update_object_field(write_client): @pytest.mark.sync -def test_update_script(write_client): +def test_update_script(write_client: Elasticsearch) -> None: Wiki.init() w = Wiki(owner=User(name="Honza Kral"), _id="elasticsearch-py", views=42) w.save() @@ -249,7 +260,7 @@ def test_update_script(write_client): @pytest.mark.sync -def test_update_script_with_dict(write_client): +def test_update_script_with_dict(write_client: Elasticsearch) -> None: Wiki.init() w = Wiki(owner=User(name="Honza Kral"), _id="elasticsearch-py", views=42) w.save() @@ -267,13 +278,16 @@ def test_update_script_with_dict(write_client): @pytest.mark.sync -def test_update_retry_on_conflict(write_client): +def test_update_retry_on_conflict(write_client: Elasticsearch) -> None: Wiki.init() w = Wiki(owner=User(name="Honza Kral"), _id="elasticsearch-py", views=42) w.save() w1 = Wiki.get(id="elasticsearch-py") w2 = Wiki.get(id="elasticsearch-py") + assert w1 is not None + assert w2 is not None + w1.update(script="ctx._source.views += params.inc", inc=5, retry_on_conflict=1) w2.update(script="ctx._source.views += params.inc", inc=5, retry_on_conflict=1) @@ -283,13 +297,18 @@ def 
test_update_retry_on_conflict(write_client): @pytest.mark.sync @pytest.mark.parametrize("retry_on_conflict", [None, 0]) -def test_update_conflicting_version(write_client, retry_on_conflict): +def test_update_conflicting_version( + write_client: Elasticsearch, retry_on_conflict: bool +) -> None: Wiki.init() w = Wiki(owner=User(name="Honza Kral"), _id="elasticsearch-py", views=42) w.save() w1 = Wiki.get(id="elasticsearch-py") w2 = Wiki.get(id="elasticsearch-py") + assert w1 is not None + assert w2 is not None + w1.update(script="ctx._source.views += params.inc", inc=5) with raises(ConflictError): @@ -301,7 +320,9 @@ def test_update_conflicting_version(write_client, retry_on_conflict): @pytest.mark.sync -def test_save_and_update_return_doc_meta(write_client): +def test_save_and_update_return_doc_meta( + write_client: Elasticsearch, +) -> None: Wiki.init() w = Wiki(owner=User(name="Honza Kral"), _id="elasticsearch-py", views=42) resp = w.save(return_doc_meta=True) @@ -334,26 +355,32 @@ def test_save_and_update_return_doc_meta(write_client): @pytest.mark.sync -def test_init(write_client): +def test_init(write_client: Elasticsearch) -> None: Repository.init(index="test-git") assert write_client.indices.exists(index="test-git") @pytest.mark.sync -def test_get_raises_404_on_index_missing(data_client): +def test_get_raises_404_on_index_missing( + data_client: Elasticsearch, +) -> None: with raises(NotFoundError): Repository.get("elasticsearch-dsl-php", index="not-there") @pytest.mark.sync -def test_get_raises_404_on_non_existent_id(data_client): +def test_get_raises_404_on_non_existent_id( + data_client: Elasticsearch, +) -> None: with raises(NotFoundError): Repository.get("elasticsearch-dsl-php") @pytest.mark.sync -def test_get_returns_none_if_404_ignored(data_client): +def test_get_returns_none_if_404_ignored( + data_client: Elasticsearch, +) -> None: assert None is Repository.get( "elasticsearch-dsl-php", using=data_client.options(ignore_status=404) ) @@ -361,15 +388,15 @@ def test_get_returns_none_if_404_ignored(data_client): @pytest.mark.sync def test_get_returns_none_if_404_ignored_and_index_doesnt_exist( - data_client, -): + data_client: Elasticsearch, +) -> None: assert None is Repository.get( "42", index="not-there", using=data_client.options(ignore_status=404) ) @pytest.mark.sync -def test_get(data_client): +def test_get(data_client: Elasticsearch) -> None: elasticsearch_repo = Repository.get("elasticsearch-dsl-py") assert isinstance(elasticsearch_repo, Repository) @@ -378,20 +405,21 @@ def test_get(data_client): @pytest.mark.sync -def test_exists_return_true(data_client): +def test_exists_return_true(data_client: Elasticsearch) -> None: assert Repository.exists("elasticsearch-dsl-py") @pytest.mark.sync -def test_exists_false(data_client): +def test_exists_false(data_client: Elasticsearch) -> None: assert not Repository.exists("elasticsearch-dsl-php") @pytest.mark.sync -def test_get_with_tz_date(data_client): +def test_get_with_tz_date(data_client: Elasticsearch) -> None: first_commit = Commit.get( id="3ca6e1e73a071a705b4babd2f581c91a2a3e5037", routing="elasticsearch-dsl-py" ) + assert first_commit is not None tzinfo = timezone("Europe/Prague") assert ( @@ -401,11 +429,13 @@ def test_get_with_tz_date(data_client): @pytest.mark.sync -def test_save_with_tz_date(data_client): +def test_save_with_tz_date(data_client: Elasticsearch) -> None: tzinfo = timezone("Europe/Prague") first_commit = Commit.get( id="3ca6e1e73a071a705b4babd2f581c91a2a3e5037", routing="elasticsearch-dsl-py" ) + assert 
first_commit is not None + first_commit.committed_date = tzinfo.localize( datetime(2014, 5, 2, 13, 47, 19, 123456) ) @@ -414,6 +444,8 @@ def test_save_with_tz_date(data_client): first_commit = Commit.get( id="3ca6e1e73a071a705b4babd2f581c91a2a3e5037", routing="elasticsearch-dsl-py" ) + assert first_commit is not None + assert ( tzinfo.localize(datetime(2014, 5, 2, 13, 47, 19, 123456)) == first_commit.committed_date @@ -429,48 +461,62 @@ def test_save_with_tz_date(data_client): @pytest.mark.sync -def test_mget(data_client): +def test_mget(data_client: Elasticsearch) -> None: commits = Commit.mget(COMMIT_DOCS_WITH_MISSING) assert commits[0] is None + assert commits[1] is not None assert commits[1].meta.id == "3ca6e1e73a071a705b4babd2f581c91a2a3e5037" assert commits[2] is None + assert commits[3] is not None assert commits[3].meta.id == "eb3e543323f189fd7b698e66295427204fff5755" @pytest.mark.sync -def test_mget_raises_exception_when_missing_param_is_invalid(data_client): +def test_mget_raises_exception_when_missing_param_is_invalid( + data_client: Elasticsearch, +) -> None: with raises(ValueError): Commit.mget(COMMIT_DOCS_WITH_MISSING, missing="raj") @pytest.mark.sync -def test_mget_raises_404_when_missing_param_is_raise(data_client): +def test_mget_raises_404_when_missing_param_is_raise( + data_client: Elasticsearch, +) -> None: with raises(NotFoundError): Commit.mget(COMMIT_DOCS_WITH_MISSING, missing="raise") @pytest.mark.sync -def test_mget_ignores_missing_docs_when_missing_param_is_skip(data_client): +def test_mget_ignores_missing_docs_when_missing_param_is_skip( + data_client: Elasticsearch, +) -> None: commits = Commit.mget(COMMIT_DOCS_WITH_MISSING, missing="skip") + assert commits[0] is not None assert commits[0].meta.id == "3ca6e1e73a071a705b4babd2f581c91a2a3e5037" + assert commits[1] is not None assert commits[1].meta.id == "eb3e543323f189fd7b698e66295427204fff5755" @pytest.mark.sync -def test_update_works_from_search_response(data_client): +def test_update_works_from_search_response( + data_client: Elasticsearch, +) -> None: elasticsearch_repo = (Repository.search().execute())[0] elasticsearch_repo.update(owner={"other_name": "elastic"}) assert "elastic" == elasticsearch_repo.owner.other_name new_version = Repository.get("elasticsearch-dsl-py") + assert new_version is not None assert "elastic" == new_version.owner.other_name assert "elasticsearch" == new_version.owner.name @pytest.mark.sync -def test_update(data_client): +def test_update(data_client: Elasticsearch) -> None: elasticsearch_repo = Repository.get("elasticsearch-dsl-py") + assert elasticsearch_repo is not None v = elasticsearch_repo.meta.version old_seq_no = elasticsearch_repo.meta.seq_no @@ -483,6 +529,7 @@ def test_update(data_client): assert elasticsearch_repo.meta.version == v + 1 new_version = Repository.get("elasticsearch-dsl-py") + assert new_version is not None assert "testing-update" == new_version.new_field assert "elastic" == new_version.owner.new_name assert "elasticsearch" == new_version.owner.name @@ -492,8 +539,9 @@ def test_update(data_client): @pytest.mark.sync -def test_save_updates_existing_doc(data_client): +def test_save_updates_existing_doc(data_client: Elasticsearch) -> None: elasticsearch_repo = Repository.get("elasticsearch-dsl-py") + assert elasticsearch_repo is not None elasticsearch_repo.new_field = "testing-save" old_seq_no = elasticsearch_repo.meta.seq_no @@ -506,7 +554,7 @@ def test_save_updates_existing_doc(data_client): @pytest.mark.sync -def test_update_empty_field(client): +def 
test_update_empty_field(client: Elasticsearch) -> None: Tags._index.delete(ignore_unavailable=True) Tags.init() d = Tags(id="123", tags=["a", "b"]) @@ -519,8 +567,11 @@ def test_update_empty_field(client): @pytest.mark.sync -def test_save_automatically_uses_seq_no_and_primary_term(data_client): +def test_save_automatically_uses_seq_no_and_primary_term( + data_client: Elasticsearch, +) -> None: elasticsearch_repo = Repository.get("elasticsearch-dsl-py") + assert elasticsearch_repo is not None elasticsearch_repo.meta.seq_no += 1 with raises(ConflictError): @@ -528,22 +579,27 @@ def test_save_automatically_uses_seq_no_and_primary_term(data_client): @pytest.mark.sync -def test_delete_automatically_uses_seq_no_and_primary_term(data_client): +def test_delete_automatically_uses_seq_no_and_primary_term( + data_client: Elasticsearch, +) -> None: elasticsearch_repo = Repository.get("elasticsearch-dsl-py") + assert elasticsearch_repo is not None elasticsearch_repo.meta.seq_no += 1 with raises(ConflictError): elasticsearch_repo.delete() -def assert_doc_equals(expected, actual): +def assert_doc_equals(expected: Any, actual: Any) -> None: for f in expected: assert f in actual assert actual[f] == expected[f] @pytest.mark.sync -def test_can_save_to_different_index(write_client): +def test_can_save_to_different_index( + write_client: Elasticsearch, +) -> None: test_repo = Repository(description="testing", meta={"id": 42}) assert test_repo.save(index="test-document") @@ -559,7 +615,9 @@ def test_can_save_to_different_index(write_client): @pytest.mark.sync -def test_save_without_skip_empty_will_include_empty_fields(write_client): +def test_save_without_skip_empty_will_include_empty_fields( + write_client: Elasticsearch, +) -> None: test_repo = Repository(field_1=[], field_2=None, field_3={}, meta={"id": 42}) assert test_repo.save(index="test-document", skip_empty=False) @@ -575,7 +633,7 @@ def test_save_without_skip_empty_will_include_empty_fields(write_client): @pytest.mark.sync -def test_delete(write_client): +def test_delete(write_client: Elasticsearch) -> None: write_client.create( index="test-document", id="elasticsearch-dsl-py", @@ -597,12 +655,14 @@ def test_delete(write_client): @pytest.mark.sync -def test_search(data_client): +def test_search(data_client: Elasticsearch) -> None: assert Repository.search().count() == 1 @pytest.mark.sync -def test_search_returns_proper_doc_classes(data_client): +def test_search_returns_proper_doc_classes( + data_client: Elasticsearch, +) -> None: result = Repository.search().execute() elasticsearch_repo = result.hits[0] @@ -612,7 +672,7 @@ def test_search_returns_proper_doc_classes(data_client): @pytest.mark.sync -def test_refresh_mapping(data_client): +def test_refresh_mapping(data_client: Elasticsearch) -> None: class Commit(Document): class Index: name = "git" @@ -627,7 +687,7 @@ class Index: @pytest.mark.sync -def test_highlight_in_meta(data_client): +def test_highlight_in_meta(data_client: Elasticsearch) -> None: commit = ( Commit.search() .query("match", description="inverting") diff --git a/tests/test_integration/_sync/test_faceted_search.py b/tests/test_integration/_sync/test_faceted_search.py index ab171631..192f57f5 100644 --- a/tests/test_integration/_sync/test_faceted_search.py +++ b/tests/test_integration/_sync/test_faceted_search.py @@ -16,10 +16,12 @@ # under the License. 
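Many hunks above insert `assert doc is not None` immediately after `get()` or `mget()`. That is type narrowing rather than a behavioral check: in the typed API these calls return an Optional (they can yield None when 404s are ignored or a doc is missing), and the assert lets mypy treat the value as the concrete class afterwards. Condensed sketch with an illustrative document class:

```python
from typing import Optional

from elasticsearch_dsl import Document, Text

class Wiki(Document):
    title = Text()

    class Index:
        name = "test-wiki"

def load(doc_id: str) -> Wiki:
    doc: Optional[Wiki] = Wiki.get(id=doc_id)  # may be None
    assert doc is not None  # narrows Optional[Wiki] -> Wiki for mypy
    return doc
```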
from datetime import datetime +from typing import Tuple, Type import pytest +from elasticsearch import Elasticsearch -from elasticsearch_dsl import A, Boolean, Date, Document, Keyword +from elasticsearch_dsl import A, Boolean, Date, Document, Keyword, Search from elasticsearch_dsl.faceted_search import ( DateHistogramFacet, FacetedSearch, @@ -57,7 +59,7 @@ class MetricSearch(FacetedSearch): @pytest.fixture(scope="session") -def commit_search_cls(es_version): +def commit_search_cls(es_version: Tuple[int, ...]) -> Type[FacetedSearch]: if es_version >= (7, 2): interval_kwargs = {"fixed_interval": "1d"} else: @@ -85,7 +87,7 @@ class CommitSearch(FacetedSearch): @pytest.fixture(scope="session") -def repo_search_cls(es_version): +def repo_search_cls(es_version: Tuple[int, ...]) -> Type[FacetedSearch]: interval_type = "calendar_interval" if es_version >= (7, 2) else "interval" class RepoSearch(FacetedSearch): @@ -98,7 +100,7 @@ class RepoSearch(FacetedSearch): ), } - def search(self): + def search(self) -> Search: s = super().search() return s.filter("term", commit_repo="repo") @@ -106,7 +108,7 @@ def search(self): @pytest.fixture(scope="session") -def pr_search_cls(es_version): +def pr_search_cls(es_version: Tuple[int, ...]) -> Type[FacetedSearch]: interval_type = "calendar_interval" if es_version >= (7, 2) else "interval" class PRSearch(FacetedSearch): @@ -125,7 +127,7 @@ class PRSearch(FacetedSearch): @pytest.mark.sync -def test_facet_with_custom_metric(data_client): +def test_facet_with_custom_metric(data_client: Elasticsearch) -> None: ms = MetricSearch() r = ms.execute() @@ -135,20 +137,24 @@ def test_facet_with_custom_metric(data_client): @pytest.mark.sync -def test_nested_facet(pull_request, pr_search_cls): +def test_nested_facet( + pull_request: PullRequest, pr_search_cls: Type[FacetedSearch] +) -> None: prs = pr_search_cls() r = prs.execute() - assert r.hits.total.value == 1 + assert r.hits.total.value == 1 # type: ignore[attr-defined] assert [(datetime(2018, 1, 1, 0, 0), 1, False)] == r.facets.comments @pytest.mark.sync -def test_nested_facet_with_filter(pull_request, pr_search_cls): +def test_nested_facet_with_filter( + pull_request: PullRequest, pr_search_cls: Type[FacetedSearch] +) -> None: prs = pr_search_cls(filters={"comments": datetime(2018, 1, 1, 0, 0)}) r = prs.execute() - assert r.hits.total.value == 1 + assert r.hits.total.value == 1 # type: ignore[attr-defined] assert [(datetime(2018, 1, 1, 0, 0), 1, True)] == r.facets.comments prs = pr_search_cls(filters={"comments": datetime(2018, 2, 1, 0, 0)}) @@ -157,31 +163,39 @@ def test_nested_facet_with_filter(pull_request, pr_search_cls): @pytest.mark.sync -def test_datehistogram_facet(data_client, repo_search_cls): +def test_datehistogram_facet( + data_client: Elasticsearch, repo_search_cls: Type[FacetedSearch] +) -> None: rs = repo_search_cls() r = rs.execute() - assert r.hits.total.value == 1 + assert r.hits.total.value == 1 # type: ignore[attr-defined] assert [(datetime(2014, 3, 1, 0, 0), 1, False)] == r.facets.created @pytest.mark.sync -def test_boolean_facet(data_client, repo_search_cls): +def test_boolean_facet( + data_client: Elasticsearch, repo_search_cls: Type[FacetedSearch] +) -> None: rs = repo_search_cls() r = rs.execute() - assert r.hits.total.value == 1 + assert r.hits.total.value == 1 # type: ignore[attr-defined] assert [(True, 1, False)] == r.facets.public value, count, selected = r.facets.public[0] assert value is True @pytest.mark.sync -def test_empty_search_finds_everything(data_client, es_version, 
commit_search_cls): +def test_empty_search_finds_everything( + data_client: Elasticsearch, + es_version: Tuple[int, ...], + commit_search_cls: Type[FacetedSearch], +) -> None: cs = commit_search_cls() r = cs.execute() - assert r.hits.total.value == 52 + assert r.hits.total.value == 52 # type: ignore[attr-defined] assert [ ("elasticsearch_dsl", 40, False), ("test_elasticsearch_dsl", 35, False), @@ -224,13 +238,13 @@ def test_empty_search_finds_everything(data_client, es_version, commit_search_cl @pytest.mark.sync def test_term_filters_are_shown_as_selected_and_data_is_filtered( - data_client, commit_search_cls -): + data_client: Elasticsearch, commit_search_cls: Type[FacetedSearch] +) -> None: cs = commit_search_cls(filters={"files": "test_elasticsearch_dsl"}) r = cs.execute() - assert 35 == r.hits.total.value + assert 35 == r.hits.total.value # type: ignore[attr-defined] assert [ ("elasticsearch_dsl", 40, False), ("test_elasticsearch_dsl", 35, True), # selected @@ -271,17 +285,19 @@ def test_term_filters_are_shown_as_selected_and_data_is_filtered( @pytest.mark.sync def test_range_filters_are_shown_as_selected_and_data_is_filtered( - data_client, commit_search_cls -): + data_client: Elasticsearch, commit_search_cls: Type[FacetedSearch] +) -> None: cs = commit_search_cls(filters={"deletions": "better"}) r = cs.execute() - assert 19 == r.hits.total.value + assert 19 == r.hits.total.value # type: ignore[attr-defined] @pytest.mark.sync -def test_pagination(data_client, commit_search_cls): +def test_pagination( + data_client: Elasticsearch, commit_search_cls: Type[FacetedSearch] +) -> None: cs = commit_search_cls() cs = cs[0:20] diff --git a/tests/test_integration/_sync/test_index.py b/tests/test_integration/_sync/test_index.py index 138ad18a..ff435bdf 100644 --- a/tests/test_integration/_sync/test_index.py +++ b/tests/test_integration/_sync/test_index.py @@ -16,6 +16,7 @@ # under the License. 
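For reference, the pagination test above relies on FacetedSearch supporting slice syntax, which it forwards to the underlying Search as `from`/`size`. A compact sketch (the facet and field choices are illustrative):

```python
from elasticsearch_dsl import FacetedSearch
from elasticsearch_dsl.faceted_search import TermsFacet

class CommitSearch(FacetedSearch):
    index = "git"
    fields = ["description"]
    facets = {"files": TermsFacet(field="files")}

cs = CommitSearch()
cs = cs[0:20]       # first page of 20 hits
# r = cs.execute()  # facet counts are unaffected by the slice
```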
import pytest +from elasticsearch import Elasticsearch from elasticsearch_dsl import Date, Document, Index, IndexTemplate, Text, analysis @@ -26,7 +27,7 @@ class Post(Document): @pytest.mark.sync -def test_index_template_works(write_client): +def test_index_template_works(write_client: Elasticsearch) -> None: it = IndexTemplate("test-template", "test-*") it.document(Post) it.settings(number_of_replicas=0, number_of_shards=1) @@ -48,7 +49,9 @@ def test_index_template_works(write_client): @pytest.mark.sync -def test_index_can_be_saved_even_with_settings(write_client): +def test_index_can_be_saved_even_with_settings( + write_client: Elasticsearch, +) -> None: i = Index("test-blog", using=write_client) i.settings(number_of_shards=3, number_of_replicas=0) i.save() @@ -62,13 +65,15 @@ def test_index_can_be_saved_even_with_settings(write_client): @pytest.mark.sync -def test_index_exists(data_client): +def test_index_exists(data_client: Elasticsearch) -> None: assert Index("git").exists() assert not Index("not-there").exists() @pytest.mark.sync -def test_index_can_be_created_with_settings_and_mappings(write_client): +def test_index_can_be_created_with_settings_and_mappings( + write_client: Elasticsearch, +) -> None: i = Index("test-blog", using=write_client) i.document(Post) i.settings(number_of_replicas=0, number_of_shards=1) @@ -94,7 +99,7 @@ def test_index_can_be_created_with_settings_and_mappings(write_client): @pytest.mark.sync -def test_delete(write_client): +def test_delete(write_client: Elasticsearch) -> None: write_client.indices.create( index="test-index", body={"settings": {"number_of_replicas": 0, "number_of_shards": 1}}, @@ -106,7 +111,9 @@ def test_delete(write_client): @pytest.mark.sync -def test_multiple_indices_with_same_doc_type_work(write_client): +def test_multiple_indices_with_same_doc_type_work( + write_client: Elasticsearch, +) -> None: i1 = Index("test-index-1", using=write_client) i2 = Index("test-index-2", using=write_client) @@ -114,8 +121,8 @@ def test_multiple_indices_with_same_doc_type_work(write_client): i.document(Post) i.create() - for i in ("test-index-1", "test-index-2"): - settings = write_client.indices.get_settings(index=i) - assert settings[i]["settings"]["index"]["analysis"] == { + for j in ("test-index-1", "test-index-2"): + settings = write_client.indices.get_settings(index=j) + assert settings[j]["settings"]["index"]["analysis"] == { "analyzer": {"my_analyzer": {"type": "custom", "tokenizer": "keyword"}} } diff --git a/tests/test_integration/_sync/test_mapping.py b/tests/test_integration/_sync/test_mapping.py index 618d1241..20bf1821 100644 --- a/tests/test_integration/_sync/test_mapping.py +++ b/tests/test_integration/_sync/test_mapping.py @@ -16,13 +16,14 @@ # under the License. 
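A note on the `i` to `j` rename in test_multiple_indices_with_same_doc_type_work (both variants): mypy infers the type of a loop variable from its first binding and rejects rebinding it to an incompatible type, and `i` is already bound to an Index instance by the create loop. A fresh name for the str loop sidesteps the redefinition error; minimal illustration:

```python
from elasticsearch_dsl import Index

i1, i2 = Index("test-index-1"), Index("test-index-2")
for i in (i1, i2):                 # mypy infers i: Index here
    i.settings(number_of_replicas=0)

# Re-binding `i` to a str would be an incompatible redefinition,
# so the second loop uses a fresh name:
for j in ("test-index-1", "test-index-2"):
    assert isinstance(j, str)
```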
import pytest +from elasticsearch import Elasticsearch from pytest import raises from elasticsearch_dsl import Mapping, analysis, exceptions @pytest.mark.sync -def test_mapping_saved_into_es(write_client): +def test_mapping_saved_into_es(write_client: Elasticsearch) -> None: m = Mapping() m.field( "name", "text", analyzer=analysis.analyzer("my_analyzer", tokenizer="keyword") @@ -44,8 +45,8 @@ def test_mapping_saved_into_es(write_client): @pytest.mark.sync def test_mapping_saved_into_es_when_index_already_exists_closed( - write_client, -): + write_client: Elasticsearch, +) -> None: m = Mapping() m.field( "name", "text", analyzer=analysis.analyzer("my_analyzer", tokenizer="keyword") @@ -70,8 +71,8 @@ def test_mapping_saved_into_es_when_index_already_exists_closed( @pytest.mark.sync def test_mapping_saved_into_es_when_index_already_exists_with_analysis( - write_client, -): + write_client: Elasticsearch, +) -> None: m = Mapping() analyzer = analysis.analyzer("my_analyzer", tokenizer="keyword") m.field("name", "text", analyzer=analyzer) @@ -101,7 +102,9 @@ def test_mapping_saved_into_es_when_index_already_exists_with_analysis( @pytest.mark.sync -def test_mapping_gets_updated_from_es(write_client): +def test_mapping_gets_updated_from_es( + write_client: Elasticsearch, +) -> None: write_client.indices.create( index="test-mapping", body={ @@ -134,7 +137,7 @@ def test_mapping_gets_updated_from_es(write_client): m = Mapping.from_es("test-mapping", using=write_client) assert ["comments", "created_at", "title"] == list( - sorted(m.properties.properties._d_.keys()) + sorted(m.properties.properties._d_.keys()) # type: ignore[attr-defined] ) assert { "date_detection": False, diff --git a/tests/test_integration/_sync/test_search.py b/tests/test_integration/_sync/test_search.py index db1f23bf..18ed8566 100644 --- a/tests/test_integration/_sync/test_search.py +++ b/tests/test_integration/_sync/test_search.py @@ -17,7 +17,7 @@ import pytest -from elasticsearch import ApiError +from elasticsearch import ApiError, Elasticsearch from pytest import raises from elasticsearch_dsl import Date, Document, Keyword, MultiSearch, Q, Search, Text @@ -32,7 +32,7 @@ class Repository(Document): tags = Keyword() @classmethod - def search(cls): + def search(cls) -> Search["Repository"]: # type: ignore[override] return super().search().filter("term", commit_repo="repo") class Index: @@ -45,7 +45,9 @@ class Index: @pytest.mark.sync -def test_filters_aggregation_buckets_are_accessible(data_client): +def test_filters_aggregation_buckets_are_accessible( + data_client: Elasticsearch, +) -> None: has_tests_query = Q("term", files="test_elasticsearch_dsl") s = Commit.search()[0:0] s.aggs.bucket("top_authors", "terms", field="author.name.raw").bucket( @@ -68,7 +70,9 @@ def test_filters_aggregation_buckets_are_accessible(data_client): @pytest.mark.sync -def test_top_hits_are_wrapped_in_response(data_client): +def test_top_hits_are_wrapped_in_response( + data_client: Elasticsearch, +) -> None: s = Commit.search()[0:0] s.aggs.bucket("top_authors", "terms", field="author.name.raw").metric( "top_commits", "top_hits", size=5 @@ -85,7 +89,9 @@ def test_top_hits_are_wrapped_in_response(data_client): @pytest.mark.sync -def test_inner_hits_are_wrapped_in_response(data_client): +def test_inner_hits_are_wrapped_in_response( + data_client: Elasticsearch, +) -> None: s = Search(index="git")[0:1].query( "has_parent", parent_type="repo", inner_hits={}, query=Q("match_all") ) @@ -99,7 +105,7 @@ def test_inner_hits_are_wrapped_in_response(data_client): 
@pytest.mark.sync -def test_scan_respects_doc_types(data_client): +def test_scan_respects_doc_types(data_client: Elasticsearch) -> None: repos = [repo for repo in Repository.search().scan()] assert 1 == len(repos) @@ -108,7 +114,9 @@ def test_scan_respects_doc_types(data_client): @pytest.mark.sync -def test_scan_iterates_through_all_docs(data_client): +def test_scan_iterates_through_all_docs( + data_client: Elasticsearch, +) -> None: s = Search(index="flat-git") commits = [commit for commit in s.scan()] @@ -118,7 +126,7 @@ def test_scan_iterates_through_all_docs(data_client): @pytest.mark.sync -def test_search_after(data_client): +def test_search_after(data_client: Elasticsearch) -> None: page_size = 7 s = Search(index="flat-git")[:page_size].sort("authored_date") commits = [] @@ -127,14 +135,14 @@ def test_search_after(data_client): commits += r.hits if len(r.hits) < page_size: break - s = r.search_after() + s = s.search_after() assert 52 == len(commits) assert {d["_id"] for d in FLAT_DATA} == {c.meta.id for c in commits} @pytest.mark.sync -def test_search_after_no_search(data_client): +def test_search_after_no_search(data_client: Elasticsearch) -> None: s = Search(index="flat-git") with raises( ValueError, match="A search must be executed before using search_after" @@ -148,7 +156,7 @@ def test_search_after_no_search(data_client): @pytest.mark.sync -def test_search_after_no_sort(data_client): +def test_search_after_no_sort(data_client: Elasticsearch) -> None: s = Search(index="flat-git") r = s.execute() with raises( @@ -158,11 +166,11 @@ def test_search_after_no_sort(data_client): @pytest.mark.sync -def test_search_after_no_results(data_client): +def test_search_after_no_results(data_client: Elasticsearch) -> None: s = Search(index="flat-git")[:100].sort("authored_date") r = s.execute() assert 52 == len(r.hits) - s = r.search_after() + s = s.search_after() r = s.execute() assert 0 == len(r.hits) with raises( @@ -172,7 +180,7 @@ def test_search_after_no_results(data_client): @pytest.mark.sync -def test_point_in_time(data_client): +def test_point_in_time(data_client: Elasticsearch) -> None: page_size = 7 commits = [] with Search(index="flat-git")[:page_size].point_in_time(keep_alive="30s") as s: @@ -182,7 +190,7 @@ def test_point_in_time(data_client): commits += r.hits if len(r.hits) < page_size: break - s = r.search_after() + s = s.search_after() assert pit_id == s._extra["pit"]["id"] assert "30s" == s._extra["pit"]["keep_alive"] @@ -191,7 +199,7 @@ def test_point_in_time(data_client): @pytest.mark.sync -def test_iterate(data_client): +def test_iterate(data_client: Elasticsearch) -> None: s = Search(index="flat-git") commits = [commit for commit in s.iterate()] @@ -201,7 +209,7 @@ def test_iterate(data_client): @pytest.mark.sync -def test_response_is_cached(data_client): +def test_response_is_cached(data_client: Elasticsearch) -> None: s = Repository.search() repos = [repo for repo in s] @@ -210,11 +218,11 @@ def test_response_is_cached(data_client): @pytest.mark.sync -def test_multi_search(data_client): +def test_multi_search(data_client: Elasticsearch) -> None: s1 = Repository.search() - s2 = Search(index="flat-git") + s2 = Search[Repository](index="flat-git") - ms = MultiSearch() + ms = MultiSearch[Repository]() ms = ms.add(s1).add(s2) r1, r2 = ms.execute() @@ -223,17 +231,17 @@ def test_multi_search(data_client): assert isinstance(r1[0], Repository) assert r1._search is s1 - assert 52 == r2.hits.total.value + assert 52 == r2.hits.total.value # type: ignore[attr-defined] assert 
r2._search is s2 @pytest.mark.sync -def test_multi_missing(data_client): +def test_multi_missing(data_client: Elasticsearch) -> None: s1 = Repository.search() - s2 = Search(index="flat-git") - s3 = Search(index="does_not_exist") + s2 = Search[Repository](index="flat-git") + s3 = Search[Repository](index="does_not_exist") - ms = MultiSearch() + ms = MultiSearch[Repository]() ms = ms.add(s1).add(s2).add(s3) with raises(ApiError): @@ -245,14 +253,16 @@ def test_multi_missing(data_client): assert isinstance(r1[0], Repository) assert r1._search is s1 - assert 52 == r2.hits.total.value + assert 52 == r2.hits.total.value # type: ignore[attr-defined] assert r2._search is s2 assert r3 is None @pytest.mark.sync -def test_raw_subfield_can_be_used_in_aggs(data_client): +def test_raw_subfield_can_be_used_in_aggs( + data_client: Elasticsearch, +) -> None: s = Search(index="git")[0:0] s.aggs.bucket("authors", "terms", field="author.name.raw", size=1) diff --git a/tests/test_integration/_sync/test_update_by_query.py b/tests/test_integration/_sync/test_update_by_query.py index 0751e54e..28df284d 100644 --- a/tests/test_integration/_sync/test_update_by_query.py +++ b/tests/test_integration/_sync/test_update_by_query.py @@ -16,13 +16,16 @@ # under the License. import pytest +from elasticsearch import Elasticsearch from elasticsearch_dsl import UpdateByQuery from elasticsearch_dsl.search import Q @pytest.mark.sync -def test_update_by_query_no_script(write_client, setup_ubq_tests): +def test_update_by_query_no_script( + write_client: Elasticsearch, setup_ubq_tests: str +) -> None: index = setup_ubq_tests ubq = ( @@ -42,7 +45,9 @@ def test_update_by_query_no_script(write_client, setup_ubq_tests): @pytest.mark.sync -def test_update_by_query_with_script(write_client, setup_ubq_tests: str): +def test_update_by_query_with_script( + write_client: Elasticsearch, setup_ubq_tests: str +) -> None: index = setup_ubq_tests ubq = ( @@ -60,7 +65,9 @@ def test_update_by_query_with_script(write_client, setup_ubq_tests: str): @pytest.mark.sync -def test_delete_by_query_with_script(write_client, setup_ubq_tests: str): +def test_delete_by_query_with_script( + write_client: Elasticsearch, setup_ubq_tests: str +) -> None: index = setup_ubq_tests ubq = ( diff --git a/tests/test_integration/test_count.py b/tests/test_integration/test_count.py index 4b2ed958..9f467f60 100644 --- a/tests/test_integration/test_count.py +++ b/tests/test_integration/test_count.py @@ -15,28 +15,32 @@ # specific language governing permissions and limitations # under the License. 
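The multi-search hunks above parametrize both Search and MultiSearch with the document class, which is what makes hits such as `r1[0]` type as the document class instead of Any. Standalone sketch of the pattern (class and index names are illustrative):

```python
from elasticsearch_dsl import Document, Keyword, MultiSearch, Search

class Repo(Document):
    tags = Keyword()

    class Index:
        name = "git"

ms = MultiSearch[Repo]()
ms = ms.add(Repo.search()).add(Search[Repo](index="flat-git"))
# r1, r2 = ms.execute()  # hits in both responses are typed as Repo
```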
+from typing import Any + +from elasticsearch import Elasticsearch + from elasticsearch_dsl.search import Q, Search -def test_count_all(data_client): +def test_count_all(data_client: Elasticsearch) -> None: s = Search(using=data_client).index("git") assert 53 == s.count() -def test_count_prefetch(data_client, mocker): +def test_count_prefetch(data_client: Elasticsearch, mocker: Any) -> None: mocker.spy(data_client, "count") search = Search(using=data_client).index("git") search.execute() assert search.count() == 53 - assert data_client.count.call_count == 0 + assert data_client.count.call_count == 0 # type: ignore[attr-defined] - search._response.hits.total.relation = "gte" + search._response.hits.total.relation = "gte" # type: ignore[attr-defined] assert search.count() == 53 - assert data_client.count.call_count == 1 + assert data_client.count.call_count == 1 # type: ignore[attr-defined] -def test_count_filter(data_client): +def test_count_filter(data_client: Elasticsearch) -> None: s = Search(using=data_client).index("git").filter(~Q("exists", field="parent_shas")) # initial commit + repo document assert 2 == s.count() diff --git a/tests/test_integration/test_data.py b/tests/test_integration/test_data.py index c78b6301..6cccf91b 100644 --- a/tests/test_integration/test_data.py +++ b/tests/test_integration/test_data.py @@ -15,8 +15,12 @@ # specific language governing permissions and limitations # under the License. +from typing import Any, Dict -def create_flat_git_index(client, index): +from elasticsearch import Elasticsearch + + +def create_flat_git_index(client: Elasticsearch, index: str) -> None: # we will use user on several places user_mapping = { "properties": {"name": {"type": "text", "fields": {"raw": {"type": "keyword"}}}} @@ -59,7 +63,7 @@ def create_flat_git_index(client, index): ) -def create_git_index(client, index): +def create_git_index(client: Elasticsearch, index: str) -> None: # we will use user on several places user_mapping = { "properties": {"name": {"type": "text", "fields": {"raw": {"type": "keyword"}}}} @@ -1081,7 +1085,7 @@ def create_git_index(client, index): ] -def flatten_doc(d): +def flatten_doc(d: Dict[str, Any]) -> Dict[str, Any]: src = d["_source"].copy() del src["commit_repo"] return {"_index": "flat-git", "_id": d["_id"], "_source": src} @@ -1090,7 +1094,7 @@ def flatten_doc(d): FLAT_DATA = [flatten_doc(d) for d in DATA if "routing" in d] -def create_test_git_data(d): +def create_test_git_data(d: Dict[str, Any]) -> Dict[str, Any]: src = d["_source"].copy() return { "_index": "test-git", diff --git a/tests/test_integration/test_examples/_async/test_alias_migration.py b/tests/test_integration/test_examples/_async/test_alias_migration.py index 9689984b..81202706 100644 --- a/tests/test_integration/test_examples/_async/test_alias_migration.py +++ b/tests/test_integration/test_examples/_async/test_alias_migration.py @@ -16,13 +16,14 @@ # under the License. 
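test_count_prefetch above pins down a caching detail: once a search has executed, `count()` is answered from the cached response as long as the reported total is exact, and only falls back to the `_count` API when the relation marks the total as a lower bound. A caller-side sketch of that logic (hedged: attribute access on the response is dynamic):

```python
from elasticsearch_dsl import Search

def cheap_count(s: Search) -> int:
    r = s.execute()              # response is cached on the Search
    total = r.hits.total         # AttrDict with `value` and `relation`
    if total.relation == "eq":   # exact total: no extra round trip
        return int(total.value)
    return s.count()             # "gte": issue a real _count request
```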
import pytest +from elasticsearch import AsyncElasticsearch from ..async_examples import alias_migration from ..async_examples.alias_migration import ALIAS, PATTERN, BlogPost, migrate @pytest.mark.asyncio -async def test_alias_migration(async_write_client): +async def test_alias_migration(async_write_client: AsyncElasticsearch) -> None: # create the index await alias_migration.setup() @@ -42,6 +43,7 @@ async def test_alias_migration(async_write_client): title="Hello World!", tags=["testing", "dummy"], content=f.read(), + published=None, ) await bp.save(refresh=True) diff --git a/tests/test_integration/test_examples/_async/test_completion.py b/tests/test_integration/test_examples/_async/test_completion.py index 6a101ca7..5b890b3d 100644 --- a/tests/test_integration/test_examples/_async/test_completion.py +++ b/tests/test_integration/test_examples/_async/test_completion.py @@ -16,15 +16,18 @@ # under the License. import pytest +from elasticsearch import AsyncElasticsearch from ..async_examples.completion import Person @pytest.mark.asyncio -async def test_person_suggests_on_all_variants_of_name(async_write_client): +async def test_person_suggests_on_all_variants_of_name( + async_write_client: AsyncElasticsearch, +) -> None: await Person.init(using=async_write_client) - await Person(name="Honza Král", popularity=42).save(refresh=True) + await Person(_id=None, name="Honza Král", popularity=42).save(refresh=True) s = Person.search().suggest("t", "kra", completion={"field": "suggest"}) response = await s.execute() diff --git a/tests/test_integration/test_examples/_async/test_composite_aggs.py b/tests/test_integration/test_examples/_async/test_composite_aggs.py index 74b1fbf7..86c88cc0 100644 --- a/tests/test_integration/test_examples/_async/test_composite_aggs.py +++ b/tests/test_integration/test_examples/_async/test_composite_aggs.py @@ -16,6 +16,7 @@ # under the License. 
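The completion test above now passes `_id=None` explicitly, presumably to satisfy the generated constructor signature of the typed Person document. For context, the suggest flow it drives looks roughly like this (the Person mapping lives in the examples package; the fields shown here are a partial assumption):

```python
from elasticsearch_dsl import Completion, Document, Text

class Person(Document):
    name = Text()
    suggest = Completion()

    class Index:
        name = "test-suggest"

s = Person.search().suggest("t", "kra", completion={"field": "suggest"})
# response = s.execute()
# options = response.suggest.t[0].options  # prefix matches for "kra"
```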
import pytest +from elasticsearch import AsyncElasticsearch from elasticsearch_dsl import A, AsyncSearch @@ -23,7 +24,9 @@ @pytest.mark.asyncio -async def test_scan_aggs_exhausts_all_files(async_data_client): +async def test_scan_aggs_exhausts_all_files( + async_data_client: AsyncElasticsearch, +) -> None: s = AsyncSearch(index="flat-git") key_aggs = {"files": A("terms", field="files")} file_list = [f async for f in scan_aggs(s, key_aggs)] @@ -32,19 +35,23 @@ async def test_scan_aggs_exhausts_all_files(async_data_client): @pytest.mark.asyncio -async def test_scan_aggs_with_multiple_aggs(async_data_client): +async def test_scan_aggs_with_multiple_aggs( + async_data_client: AsyncElasticsearch, +) -> None: s = AsyncSearch(index="flat-git") key_aggs = [ {"files": A("terms", field="files")}, { - "months": { - "date_histogram": { - "field": "committed_date", - "calendar_interval": "month", - } - } + "months": A( + "date_histogram", field="committed_date", calendar_interval="month" + ) }, ] - file_list = [f async for f in scan_aggs(s, key_aggs)] + file_list = [ + f + async for f in scan_aggs( + s, key_aggs, {"first_seen": A("min", field="committed_date")} + ) + ] assert len(file_list) == 47 diff --git a/tests/test_integration/test_examples/_async/test_parent_child.py b/tests/test_integration/test_examples/_async/test_parent_child.py index 667ffd0b..9a1027f4 100644 --- a/tests/test_integration/test_examples/_async/test_parent_child.py +++ b/tests/test_integration/test_examples/_async/test_parent_child.py @@ -19,6 +19,7 @@ import pytest import pytest_asyncio +from elasticsearch import AsyncElasticsearch from elasticsearch_dsl import Q @@ -42,7 +43,7 @@ @pytest_asyncio.fixture -async def question(async_write_client): +async def question(async_write_client: AsyncElasticsearch) -> Question: await setup() assert await async_write_client.indices.exists_template(name="base") @@ -55,16 +56,21 @@ async def question(async_write_client): body=""" I want to use elasticsearch, how do I do it from Python? """, + created=None, + question_answer=None, + comments=[], ) await q.save() return q @pytest.mark.asyncio -async def test_comment(async_write_client, question): +async def test_comment( + async_write_client: AsyncElasticsearch, question: Question +) -> None: await question.add_comment(nick, "Just use elasticsearch-py") - q = await Question.get(1) + q = await Question.get(1) # type: ignore[arg-type] assert isinstance(q, Question) assert 1 == len(q.comments) @@ -74,7 +80,9 @@ async def test_comment(async_write_client, question): @pytest.mark.asyncio -async def test_question_answer(async_write_client, question): +async def test_question_answer( + async_write_client: AsyncElasticsearch, question: Question +) -> None: a = await question.add_answer(honza, "Just use `elasticsearch-py`!") assert isinstance(a, Answer) diff --git a/tests/test_integration/test_examples/_async/test_percolate.py b/tests/test_integration/test_examples/_async/test_percolate.py index ef335381..d1564d94 100644 --- a/tests/test_integration/test_examples/_async/test_percolate.py +++ b/tests/test_integration/test_examples/_async/test_percolate.py @@ -16,12 +16,15 @@ # under the License. 
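The composite-aggregation hunks above swap the raw dict source for an `A(...)` object and exercise the optional third argument of scan_aggs, which attaches per-bucket metric aggregations. Equivalent standalone setup (scan_aggs itself comes from the composite_agg example module):

```python
from elasticsearch_dsl import A, Search

s = Search(index="flat-git")
sources = [
    {"files": A("terms", field="files")},
    {"months": A("date_histogram", field="committed_date",
                 calendar_interval="month")},
]
metrics = {"first_seen": A("min", field="committed_date")}
# buckets = [b for b in scan_aggs(s, sources, metrics)]
```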
import pytest +from elasticsearch import AsyncElasticsearch from ..async_examples.percolate import BlogPost, setup @pytest.mark.asyncio -async def test_post_gets_tagged_automatically(async_write_client): +async def test_post_gets_tagged_automatically( + async_write_client: AsyncElasticsearch, +) -> None: await setup() bp = BlogPost(_id=47, content="nothing about snakes here!") diff --git a/tests/test_integration/test_examples/_async/test_vectors.py b/tests/test_integration/test_examples/_async/test_vectors.py index effb48cf..dedeeadf 100644 --- a/tests/test_integration/test_examples/_async/test_vectors.py +++ b/tests/test_integration/test_examples/_async/test_vectors.py @@ -16,9 +16,11 @@ # under the License. from hashlib import md5 +from typing import Any, List, Tuple from unittest import SkipTest import pytest +from elasticsearch import AsyncElasticsearch from tests.async_sleep import sleep @@ -26,17 +28,19 @@ @pytest.mark.asyncio -async def test_vector_search(async_write_client, es_version, mocker): +async def test_vector_search( + async_write_client: AsyncElasticsearch, es_version: Tuple[int, ...], mocker: Any +) -> None: # this test only runs on Elasticsearch >= 8.11 because the example uses # a dense vector without specifying an explicit size if es_version < (8, 11): raise SkipTest("This test requires Elasticsearch 8.11 or newer") class MockModel: - def __init__(self, model): + def __init__(self, model: Any): pass - def encode(self, text): + def encode(self, text: str) -> List[float]: vector = [int(ch) for ch in md5(text.encode()).digest()] total = sum(vector) return [float(v) / total for v in vector] diff --git a/tests/test_integration/test_examples/_sync/test_alias_migration.py b/tests/test_integration/test_examples/_sync/test_alias_migration.py index b0dccaa0..59cdb372 100644 --- a/tests/test_integration/test_examples/_sync/test_alias_migration.py +++ b/tests/test_integration/test_examples/_sync/test_alias_migration.py @@ -16,13 +16,14 @@ # under the License. import pytest +from elasticsearch import Elasticsearch from ..examples import alias_migration from ..examples.alias_migration import ALIAS, PATTERN, BlogPost, migrate @pytest.mark.sync -def test_alias_migration(write_client): +def test_alias_migration(write_client: Elasticsearch) -> None: # create the index alias_migration.setup() @@ -42,6 +43,7 @@ def test_alias_migration(write_client): title="Hello World!", tags=["testing", "dummy"], content=f.read(), + published=None, ) bp.save(refresh=True) diff --git a/tests/test_integration/test_examples/_sync/test_completion.py b/tests/test_integration/test_examples/_sync/test_completion.py index e391bad5..2e922710 100644 --- a/tests/test_integration/test_examples/_sync/test_completion.py +++ b/tests/test_integration/test_examples/_sync/test_completion.py @@ -16,15 +16,18 @@ # under the License. 
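The vector tests above keep working without a real embedding model by hashing the input text; here is the MockModel trick extracted as a standalone helper (deterministic, normalized so the components sum to 1.0, no ML dependency):

```python
from hashlib import md5
from typing import List

def fake_embedding(text: str) -> List[float]:
    vector = [int(ch) for ch in md5(text.encode()).digest()]  # 16 bytes
    total = sum(vector)
    return [float(v) / total for v in vector]  # components sum to 1.0

assert abs(sum(fake_embedding("Hello World!")) - 1.0) < 1e-9
```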
import pytest +from elasticsearch import Elasticsearch from ..examples.completion import Person @pytest.mark.sync -def test_person_suggests_on_all_variants_of_name(write_client): +def test_person_suggests_on_all_variants_of_name( + write_client: Elasticsearch, +) -> None: Person.init(using=write_client) - Person(name="Honza Král", popularity=42).save(refresh=True) + Person(_id=None, name="Honza Král", popularity=42).save(refresh=True) s = Person.search().suggest("t", "kra", completion={"field": "suggest"}) response = s.execute() diff --git a/tests/test_integration/test_examples/_sync/test_composite_aggs.py b/tests/test_integration/test_examples/_sync/test_composite_aggs.py index bd831f4f..990987dd 100644 --- a/tests/test_integration/test_examples/_sync/test_composite_aggs.py +++ b/tests/test_integration/test_examples/_sync/test_composite_aggs.py @@ -16,6 +16,7 @@ # under the License. import pytest +from elasticsearch import Elasticsearch from elasticsearch_dsl import A, Search @@ -23,7 +24,9 @@ @pytest.mark.sync -def test_scan_aggs_exhausts_all_files(data_client): +def test_scan_aggs_exhausts_all_files( + data_client: Elasticsearch, +) -> None: s = Search(index="flat-git") key_aggs = {"files": A("terms", field="files")} file_list = [f for f in scan_aggs(s, key_aggs)] @@ -32,19 +35,23 @@ def test_scan_aggs_exhausts_all_files(data_client): @pytest.mark.sync -def test_scan_aggs_with_multiple_aggs(data_client): +def test_scan_aggs_with_multiple_aggs( + data_client: Elasticsearch, +) -> None: s = Search(index="flat-git") key_aggs = [ {"files": A("terms", field="files")}, { - "months": { - "date_histogram": { - "field": "committed_date", - "calendar_interval": "month", - } - } + "months": A( + "date_histogram", field="committed_date", calendar_interval="month" + ) }, ] - file_list = [f for f in scan_aggs(s, key_aggs)] + file_list = [ + f + for f in scan_aggs( + s, key_aggs, {"first_seen": A("min", field="committed_date")} + ) + ] assert len(file_list) == 47 diff --git a/tests/test_integration/test_examples/_sync/test_parent_child.py b/tests/test_integration/test_examples/_sync/test_parent_child.py index 4c82cdd7..dcbbde86 100644 --- a/tests/test_integration/test_examples/_sync/test_parent_child.py +++ b/tests/test_integration/test_examples/_sync/test_parent_child.py @@ -18,6 +18,7 @@ from datetime import datetime import pytest +from elasticsearch import Elasticsearch from elasticsearch_dsl import Q @@ -41,7 +42,7 @@ @pytest.fixture -def question(write_client): +def question(write_client: Elasticsearch) -> Question: setup() assert write_client.indices.exists_template(name="base") @@ -54,16 +55,19 @@ def question(write_client): body=""" I want to use elasticsearch, how do I do it from Python? 
""", + created=None, + question_answer=None, + comments=[], ) q.save() return q @pytest.mark.sync -def test_comment(write_client, question): +def test_comment(write_client: Elasticsearch, question: Question) -> None: question.add_comment(nick, "Just use elasticsearch-py") - q = Question.get(1) + q = Question.get(1) # type: ignore[arg-type] assert isinstance(q, Question) assert 1 == len(q.comments) @@ -73,7 +77,7 @@ def test_comment(write_client, question): @pytest.mark.sync -def test_question_answer(write_client, question): +def test_question_answer(write_client: Elasticsearch, question: Question) -> None: a = question.add_answer(honza, "Just use `elasticsearch-py`!") assert isinstance(a, Answer) diff --git a/tests/test_integration/test_examples/_sync/test_percolate.py b/tests/test_integration/test_examples/_sync/test_percolate.py index 8075a225..925d362c 100644 --- a/tests/test_integration/test_examples/_sync/test_percolate.py +++ b/tests/test_integration/test_examples/_sync/test_percolate.py @@ -16,12 +16,15 @@ # under the License. import pytest +from elasticsearch import Elasticsearch from ..examples.percolate import BlogPost, setup @pytest.mark.sync -def test_post_gets_tagged_automatically(write_client): +def test_post_gets_tagged_automatically( + write_client: Elasticsearch, +) -> None: setup() bp = BlogPost(_id=47, content="nothing about snakes here!") diff --git a/tests/test_integration/test_examples/_sync/test_vectors.py b/tests/test_integration/test_examples/_sync/test_vectors.py index 395aa408..2c0fe9ff 100644 --- a/tests/test_integration/test_examples/_sync/test_vectors.py +++ b/tests/test_integration/test_examples/_sync/test_vectors.py @@ -16,9 +16,11 @@ # under the License. from hashlib import md5 +from typing import Any, List, Tuple from unittest import SkipTest import pytest +from elasticsearch import Elasticsearch from tests.sleep import sleep @@ -26,17 +28,19 @@ @pytest.mark.sync -def test_vector_search(write_client, es_version, mocker): +def test_vector_search( + write_client: Elasticsearch, es_version: Tuple[int, ...], mocker: Any +) -> None: # this test only runs on Elasticsearch >= 8.11 because the example uses # a dense vector without specifying an explicit size if es_version < (8, 11): raise SkipTest("This test requires Elasticsearch 8.11 or newer") class MockModel: - def __init__(self, model): + def __init__(self, model: Any): pass - def encode(self, text): + def encode(self, text: str) -> List[float]: vector = [int(ch) for ch in md5(text.encode()).digest()] total = sum(vector) return [float(v) / total for v in vector] diff --git a/tests/test_package.py b/tests/test_package.py index 8f8075dc..0102c86c 100644 --- a/tests/test_package.py +++ b/tests/test_package.py @@ -18,5 +18,5 @@ import elasticsearch_dsl -def test__all__is_sorted(): +def test__all__is_sorted() -> None: assert elasticsearch_dsl.__all__ == sorted(elasticsearch_dsl.__all__) diff --git a/tests/test_result.py b/tests/test_result.py index 15e6ef7a..713166ca 100644 --- a/tests/test_result.py +++ b/tests/test_result.py @@ -17,20 +17,22 @@ import pickle from datetime import date +from typing import Any, Dict from pytest import fixture, raises from elasticsearch_dsl import Date, Document, Object, Search, response from elasticsearch_dsl.aggs import Terms from elasticsearch_dsl.response.aggs import AggResponse, Bucket, BucketData +from elasticsearch_dsl.utils import AttrDict @fixture -def agg_response(aggs_search, aggs_data): +def agg_response(aggs_search: Search, aggs_data: Dict[str, Any]) -> 
diff --git a/tests/test_result.py b/tests/test_result.py
index 15e6ef7a..713166ca 100644
--- a/tests/test_result.py
+++ b/tests/test_result.py
@@ -17,20 +17,22 @@

 import pickle
 from datetime import date
+from typing import Any, Dict

 from pytest import fixture, raises

 from elasticsearch_dsl import Date, Document, Object, Search, response
 from elasticsearch_dsl.aggs import Terms
 from elasticsearch_dsl.response.aggs import AggResponse, Bucket, BucketData
+from elasticsearch_dsl.utils import AttrDict


 @fixture
-def agg_response(aggs_search, aggs_data):
+def agg_response(aggs_search: Search, aggs_data: Dict[str, Any]) -> response.Response:
     return response.Response(aggs_search, aggs_data)


-def test_agg_response_is_pickleable(agg_response):
+def test_agg_response_is_pickleable(agg_response: response.Response) -> None:
     agg_response.hits

     r = pickle.loads(pickle.dumps(agg_response))
@@ -39,8 +41,8 @@
     assert r.hits == agg_response.hits


-def test_response_is_pickleable(dummy_response):
-    res = response.Response(Search(), dummy_response.body)
+def test_response_is_pickleable(dummy_response: Dict[str, Any]) -> None:
+    res = response.Response(Search(), dummy_response.body)  # type: ignore[attr-defined]
     res.hits

     r = pickle.loads(pickle.dumps(res))
@@ -49,7 +51,7 @@
     assert r.hits == res.hits


-def test_hit_is_pickleable(dummy_response):
+def test_hit_is_pickleable(dummy_response: Dict[str, Any]) -> None:
     res = response.Response(Search(), dummy_response)
     hits = pickle.loads(pickle.dumps(res.hits))

@@ -57,15 +59,15 @@
     assert hits[0].meta == res.hits[0].meta


-def test_response_stores_search(dummy_response):
+def test_response_stores_search(dummy_response: Dict[str, Any]) -> None:
     s = Search()
     r = response.Response(s, dummy_response)

     assert r._search is s


-def test_attribute_error_in_hits_is_not_hidden(dummy_response):
-    def f(hit):
+def test_attribute_error_in_hits_is_not_hidden(dummy_response: Dict[str, Any]) -> None:
+    def f(hit: AttrDict[Any]) -> Any:
         raise AttributeError()

     s = Search().doc_type(employee=f)
@@ -74,7 +76,7 @@ def f(hit):
         r.hits


-def test_interactive_helpers(dummy_response):
+def test_interactive_helpers(dummy_response: Dict[str, Any]) -> None:
     res = response.Response(Search(), dummy_response)
     hits = res.hits
     h = hits[0]
@@ -97,19 +99,19 @@
     ] == repr(h)


-def test_empty_response_is_false(dummy_response):
+def test_empty_response_is_false(dummy_response: Dict[str, Any]) -> None:
     dummy_response["hits"]["hits"] = []
     res = response.Response(Search(), dummy_response)

     assert not res


-def test_len_response(dummy_response):
+def test_len_response(dummy_response: Dict[str, Any]) -> None:
     res = response.Response(Search(), dummy_response)
     assert len(res) == 4


-def test_iterating_over_response_gives_you_hits(dummy_response):
+def test_iterating_over_response_gives_you_hits(dummy_response: Dict[str, Any]) -> None:
     res = response.Response(Search(), dummy_response)
     hits = list(h for h in res)

@@ -127,15 +129,19 @@
     assert hits[1].meta.routing == "elasticsearch"


-def test_hits_get_wrapped_to_contain_additional_attrs(dummy_response):
+def test_hits_get_wrapped_to_contain_additional_attrs(
+    dummy_response: Dict[str, Any]
+) -> None:
     res = response.Response(Search(), dummy_response)
     hits = res.hits

-    assert 123 == hits.total
-    assert 12.0 == hits.max_score
+    assert 123 == hits.total  # type: ignore[attr-defined]
+    assert 12.0 == hits.max_score  # type: ignore[attr-defined]


-def test_hits_provide_dot_and_bracket_access_to_attrs(dummy_response):
+def test_hits_provide_dot_and_bracket_access_to_attrs(
+    dummy_response: Dict[str, Any]
+) -> None:
     res = response.Response(Search(), dummy_response)
     h = res.hits[0]
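# Reviewer illustration (not part of the patch): hits are AttrDict wrappers
# over the raw response dicts, which is why dot and bracket access are
# interchangeable in the test above, and why dynamic attributes such as
# hits.total need a "type: ignore[attr-defined]" despite resolving fine at
# runtime:

from typing import Any

from elasticsearch_dsl.utils import AttrDict

hit = AttrDict[Any]({"title": "Hello World!", "city": "Prague"})
assert hit.title == hit["title"] == "Hello World!"
assert hit.to_dict() == {"title": "Hello World!", "city": "Prague"}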
@@ -151,30 +157,32 @@
     h.not_there


-def test_slicing_on_response_slices_on_hits(dummy_response):
+def test_slicing_on_response_slices_on_hits(dummy_response: Dict[str, Any]) -> None:
     res = response.Response(Search(), dummy_response)

     assert res[0] is res.hits[0]
     assert res[::-1] == res.hits[::-1]


-def test_aggregation_base(agg_response):
+def test_aggregation_base(agg_response: response.Response) -> None:
     assert agg_response.aggs is agg_response.aggregations
     assert isinstance(agg_response.aggs, response.AggResponse)


-def test_metric_agg_works(agg_response):
+def test_metric_agg_works(agg_response: response.Response) -> None:
     assert 25052.0 == agg_response.aggs.sum_lines.value


-def test_aggregations_can_be_iterated_over(agg_response):
+def test_aggregations_can_be_iterated_over(agg_response: response.Response) -> None:
     aggs = [a for a in agg_response.aggs]

     assert len(aggs) == 3
     assert all(map(lambda a: isinstance(a, AggResponse), aggs))


-def test_aggregations_can_be_retrieved_by_name(agg_response, aggs_search):
+def test_aggregations_can_be_retrieved_by_name(
+    agg_response: response.Response, aggs_search: Search
+) -> None:
     a = agg_response.aggs["popular_files"]

     assert isinstance(a, BucketData)
@@ -182,7 +190,7 @@
     assert a._meta["aggs"] is aggs_search.aggs.aggs["popular_files"]


-def test_bucket_response_can_be_iterated_over(agg_response):
+def test_bucket_response_can_be_iterated_over(agg_response: response.Response) -> None:
     popular_files = agg_response.aggregations.popular_files

     buckets = [b for b in popular_files]
@@ -190,7 +198,9 @@
     assert buckets == popular_files.buckets


-def test_bucket_keys_get_deserialized(aggs_data, aggs_search):
+def test_bucket_keys_get_deserialized(
+    aggs_data: Dict[str, Any], aggs_search: Search
+) -> None:
     class Commit(Document):
         info = Object(properties={"committed_date": Date()})
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 8771b0c5..7c8bf232 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -31,7 +31,7 @@ def test_attrdict_pickle() -> None:


 def test_attrlist_pickle() -> None:
-    al = utils.AttrList([])
+    al = utils.AttrList[Any]([])

     pickled_al = pickle.dumps(al)
     assert al == pickle.loads(pickled_al)
@@ -41,7 +41,7 @@ def test_attrlist_slice() -> None:
     class MyAttrDict(utils.AttrDict[str]):
         pass

-    l = utils.AttrList([{}, {}], obj_wrapper=MyAttrDict)
+    l = utils.AttrList[Any]([{}, {}], obj_wrapper=MyAttrDict)
     assert isinstance(l[:][0], MyAttrDict)


@@ -111,6 +111,6 @@ def test_recursive_to_dict() -> None:


 def test_attrlist_to_list() -> None:
-    l = utils.AttrList([{}, {}]).to_list()
+    l = utils.AttrList[Any]([{}, {}]).to_list()
     assert isinstance(l, list)
     assert l == [{}, {}]
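# Reviewer illustration (not part of the patch): AttrList is now generic in
# its element type, so the tests parametrize it explicitly; the subscript
# only matters to the type checker and changes nothing at runtime:

from typing import Any

from elasticsearch_dsl import utils

al = utils.AttrList[Any]([{"a": 1}, {"b": 2}])
assert al.to_list() == [{"a": 1}, {"b": 2}]
assert isinstance(al[0], utils.AttrDict)  # items are wrapped on access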
diff --git a/tests/test_validation.py b/tests/test_validation.py
index 9368b82d..ae54f50c 100644
--- a/tests/test_validation.py
+++ b/tests/test_validation.py
@@ -16,11 +16,11 @@
 # under the License.

 from datetime import datetime
+from typing import Any

 from pytest import raises

 from elasticsearch_dsl import (
-    Boolean,
     Date,
     Document,
     InnerDoc,
@@ -28,15 +28,16 @@
     Nested,
     Object,
     Text,
+    mapped_field,
 )
 from elasticsearch_dsl.exceptions import ValidationException


 class Author(InnerDoc):
-    name = Text(required=True)
-    email = Text(required=True)
+    name: str
+    email: str

-    def clean(self):
+    def clean(self) -> None:
         if not self.name:
             raise ValidationException("name is missing")
         if not self.email:
@@ -52,11 +53,11 @@
 class BlogPostWithStatus(Document):
-    published = Boolean(required=True)
+    published: bool = mapped_field(init=False)


 class AutoNowDate(Date):
-    def clean(self, data):
+    def clean(self, data: Any) -> Any:
         if data is None:
             data = datetime.now()
         return super().clean(data)
@@ -67,15 +68,15 @@ class Log(Document):
     data = Text()


-def test_required_int_can_be_0():
+def test_required_int_can_be_0() -> None:
     class DT(Document):
         i = Integer(required=True)

     dt = DT(i=0)
-    assert dt.full_clean() is None
+    dt.full_clean()


-def test_required_field_cannot_be_empty_list():
+def test_required_field_cannot_be_empty_list() -> None:
     class DT(Document):
         i = Integer(required=True)

@@ -84,7 +85,7 @@ class DT(Document):
         dt.full_clean()


-def test_validation_works_for_lists_of_values():
+def test_validation_works_for_lists_of_values() -> None:
     class DT(Document):
         i = Date(required=True)

@@ -93,24 +94,24 @@ class DT(Document):
         dt.full_clean()

     dt = DT(i=[datetime.now(), datetime.now()])
-    assert None is dt.full_clean()
+    dt.full_clean()


-def test_field_with_custom_clean():
+def test_field_with_custom_clean() -> None:
     l = Log()
     l.full_clean()

     assert isinstance(l.timestamp, datetime)


-def test_empty_object():
+def test_empty_object() -> None:
     d = BlogPost(authors=[{"name": "Honza", "email": "honza@elastic.co"}])
-    d.inner = {}
+    d.inner = {}  # type: ignore[assignment]

     d.full_clean()


-def test_missing_required_field_raises_validation_exception():
+def test_missing_required_field_raises_validation_exception() -> None:
     d = BlogPost()
     with raises(ValidationException):
         d.full_clean()
@@ -125,7 +126,7 @@
         d.full_clean()


-def test_boolean_doesnt_treat_false_as_empty():
+def test_boolean_doesnt_treat_false_as_empty() -> None:
     d = BlogPostWithStatus()
     with raises(ValidationException):
         d.full_clean()
@@ -135,26 +136,26 @@
         d.full_clean()


-def test_custom_validation_on_nested_gets_run():
+def test_custom_validation_on_nested_gets_run() -> None:
     d = BlogPost(authors=[Author(name="Honza", email="king@example.com")], created=None)

-    assert isinstance(d.authors[0], Author)
+    assert isinstance(d.authors[0], Author)  # type: ignore[index]

     with raises(ValidationException):
         d.full_clean()


-def test_accessing_known_fields_returns_empty_value():
+def test_accessing_known_fields_returns_empty_value() -> None:
     d = BlogPost()

     assert [] == d.authors

     d.authors.append({})
-    assert None is d.authors[0].name
+    assert None is d.authors[0].name  # type: ignore[index]
     assert None is d.authors[0].email


-def test_empty_values_are_not_serialized():
+def test_empty_values_are_not_serialized() -> None:
     d = BlogPost(authors=[{"name": "Honza", "email": "honza@elastic.co"}], created=None)

     d.full_clean()
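# Reviewer illustration (not part of the patch): the mapped_field(init=False)
# change above keeps "published" out of the type-checked constructor, while
# the bare bool annotation still marks the field as required for validation.
# A sketch with a hypothetical document class:

from elasticsearch_dsl import Document, mapped_field


class StatusDoc(Document):
    published: bool = mapped_field(init=False)


d = StatusDoc()      # no constructor argument expected by mypy
d.published = False  # False is a real value, not a missing one
d.full_clean()       # passes; with the field left unset it would raise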
"tests/test_integration/_sync/", ), ( - "tests/test_integration/test_examples/_async", + "tests/test_integration/test_examples/_async/", "tests/test_integration/test_examples/_sync/", ), ("examples/async/", "examples/"),