diff --git a/elasticsearch_dsl/utils.py b/elasticsearch_dsl/utils.py index 6e311316..7fd0b08c 100644 --- a/elasticsearch_dsl/utils.py +++ b/elasticsearch_dsl/utils.py @@ -18,13 +18,25 @@ import collections.abc from copy import copy -from typing import Any, ClassVar, Dict, List, Optional, Type, Union +from typing import Any, ClassVar, Dict, Generic, List, Optional, Type, TypeVar, Union -from typing_extensions import Self +from typing_extensions import Self, TypeAlias from .exceptions import UnknownDslObject, ValidationException -JSONType = Union[int, bool, str, float, List["JSONType"], Dict[str, "JSONType"]] +# Usefull types + +JSONType: TypeAlias = Union[ + int, bool, str, float, List["JSONType"], Dict[str, "JSONType"] +] + + +# Type variables for internals + +_KeyT = TypeVar("_KeyT") +_ValT = TypeVar("_ValT") + +# Constants SKIP_VALUES = ("", None) EXPAND__TO_DOT = True @@ -110,18 +122,20 @@ def to_list(self): return self._l_ -class AttrDict: +class AttrDict(Generic[_KeyT, _ValT]): """ Helper class to provide attribute like access (read and write) to dictionaries. Used to provide a convenient way to access both results and nested dsl dicts. """ - def __init__(self, d): + _d_: Dict[_KeyT, _ValT] + + def __init__(self, d: Dict[_KeyT, _ValT]): # assign the inner dict manually to prevent __setattr__ from firing super().__setattr__("_d_", d) - def __contains__(self, key): + def __contains__(self, key: object) -> bool: return key in self._d_ def __nonzero__(self): diff --git a/elasticsearch_dsl/wrappers.py b/elasticsearch_dsl/wrappers.py index 0dbca982..4b6a9db5 100644 --- a/elasticsearch_dsl/wrappers.py +++ b/elasticsearch_dsl/wrappers.py @@ -16,26 +16,61 @@ # under the License. import operator +from typing import ( + TYPE_CHECKING, + Callable, + ClassVar, + Dict, + Literal, + Mapping, + Optional, + Tuple, + TypeVar, + Union, + cast, +) + +if TYPE_CHECKING: + from _operator import _SupportsComparison + +from typing_extensions import TypeAlias from .utils import AttrDict +ComparisonOperators: TypeAlias = Literal["lt", "lte", "gt", "gte"] +RangeValT = TypeVar("RangeValT", bound="_SupportsComparison") + __all__ = ["Range"] -class Range(AttrDict): - OPS = { +class Range(AttrDict[ComparisonOperators, RangeValT]): + OPS: ClassVar[ + Mapping[ + ComparisonOperators, + Callable[["_SupportsComparison", "_SupportsComparison"], bool], + ] + ] = { "lt": operator.lt, "lte": operator.le, "gt": operator.gt, "gte": operator.ge, } - def __init__(self, *args, **kwargs): - if args and (len(args) > 1 or kwargs or not isinstance(args[0], dict)): + def __init__( + self, + d: Optional[Dict[ComparisonOperators, RangeValT]] = None, + /, + **kwargs: RangeValT, + ): + if d is not None and (kwargs or not isinstance(d, dict)): raise ValueError( "Range accepts a single dictionary or a set of keyword arguments." ) - data = args[0] if args else kwargs + + if d is None: + data = cast(Dict[ComparisonOperators, RangeValT], kwargs) + else: + data = d for k in data: if k not in self.OPS: @@ -47,22 +82,28 @@ def __init__(self, *args, **kwargs): if "lt" in data and "lte" in data: raise ValueError("You cannot specify both lt and lte for Range.") - super().__init__(args[0] if args else kwargs) + super().__init__(data) - def __repr__(self): + def __repr__(self) -> str: return "Range(%s)" % ", ".join("%s=%r" % op for op in self._d_.items()) - def __contains__(self, item): + def __contains__(self, item: object) -> bool: if isinstance(item, str): return super().__contains__(item) + item_supports_comp = any(hasattr(item, f"__{op}__") for op in self.OPS) + if not item_supports_comp: + return False + for op in self.OPS: - if op in self._d_ and not self.OPS[op](item, self._d_[op]): + if op in self._d_ and not self.OPS[op]( + cast("_SupportsComparison", item), self._d_[op] + ): return False return True @property - def upper(self): + def upper(self) -> Union[Tuple[RangeValT, bool], Tuple[None, Literal[False]]]: if "lt" in self._d_: return self._d_["lt"], False if "lte" in self._d_: @@ -70,7 +111,7 @@ def upper(self): return None, False @property - def lower(self): + def lower(self) -> Union[Tuple[RangeValT, bool], Tuple[None, Literal[False]]]: if "gt" in self._d_: return self._d_["gt"], False if "gte" in self._d_: diff --git a/noxfile.py b/noxfile.py index f90f22f0..a731f99b 100644 --- a/noxfile.py +++ b/noxfile.py @@ -32,7 +32,9 @@ TYPED_FILES = ( "elasticsearch_dsl/function.py", "elasticsearch_dsl/query.py", + "elasticsearch_dsl/wrappers.py", "tests/test_query.py", + "tests/test_wrappers.py", ) diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index 45472271..f8acd586 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -16,6 +16,10 @@ # under the License. from datetime import datetime, timedelta +from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence + +if TYPE_CHECKING: + from _operator import _SupportsComparison import pytest @@ -34,7 +38,9 @@ ({"gt": datetime.now() - timedelta(seconds=10)}, datetime.now()), ], ) -def test_range_contains(kwargs, item): +def test_range_contains( + kwargs: Mapping[str, "_SupportsComparison"], item: "_SupportsComparison" +) -> None: assert item in Range(**kwargs) @@ -48,7 +54,9 @@ def test_range_contains(kwargs, item): ({"lte": datetime.now() - timedelta(seconds=10)}, datetime.now()), ], ) -def test_range_not_contains(kwargs, item): +def test_range_not_contains( + kwargs: Mapping[str, "_SupportsComparison"], item: "_SupportsComparison" +) -> None: assert item not in Range(**kwargs) @@ -62,7 +70,9 @@ def test_range_not_contains(kwargs, item): ((), {"gt": 1, "gte": 1}), ], ) -def test_range_raises_value_error_on_wrong_params(args, kwargs): +def test_range_raises_value_error_on_wrong_params( + args: Sequence[Any], kwargs: Mapping[str, "_SupportsComparison"] +) -> None: with pytest.raises(ValueError): Range(*args, **kwargs) @@ -76,7 +86,11 @@ def test_range_raises_value_error_on_wrong_params(args, kwargs): (Range(lt=42), None, False), ], ) -def test_range_lower(range, lower, inclusive): +def test_range_lower( + range: Range["_SupportsComparison"], + lower: Optional["_SupportsComparison"], + inclusive: bool, +) -> None: assert (lower, inclusive) == range.lower @@ -89,5 +103,9 @@ def test_range_lower(range, lower, inclusive): (Range(gt=42), None, False), ], ) -def test_range_upper(range, upper, inclusive): +def test_range_upper( + range: Range["_SupportsComparison"], + upper: Optional["_SupportsComparison"], + inclusive: bool, +) -> None: assert (upper, inclusive) == range.upper