From 90f43caf0c3d542687663060624b0d296d77b86e Mon Sep 17 00:00:00 2001 From: Caio Fontes Date: Thu, 6 Jun 2024 20:27:16 +0200 Subject: [PATCH 1/5] refactor: add type hints to wrappers.py --- elasticsearch_dsl/utils.py | 26 ++++++++--- elasticsearch_dsl/wrappers.py | 84 ++++++++++++++++++++++++++++++----- noxfile.py | 2 + tests/test_wrappers.py | 26 ++++++++--- 4 files changed, 116 insertions(+), 22 deletions(-) diff --git a/elasticsearch_dsl/utils.py b/elasticsearch_dsl/utils.py index 6e311316d..7fd0b08c7 100644 --- a/elasticsearch_dsl/utils.py +++ b/elasticsearch_dsl/utils.py @@ -18,13 +18,25 @@ import collections.abc from copy import copy -from typing import Any, ClassVar, Dict, List, Optional, Type, Union +from typing import Any, ClassVar, Dict, Generic, List, Optional, Type, TypeVar, Union -from typing_extensions import Self +from typing_extensions import Self, TypeAlias from .exceptions import UnknownDslObject, ValidationException -JSONType = Union[int, bool, str, float, List["JSONType"], Dict[str, "JSONType"]] +# Usefull types + +JSONType: TypeAlias = Union[ + int, bool, str, float, List["JSONType"], Dict[str, "JSONType"] +] + + +# Type variables for internals + +_KeyT = TypeVar("_KeyT") +_ValT = TypeVar("_ValT") + +# Constants SKIP_VALUES = ("", None) EXPAND__TO_DOT = True @@ -110,18 +122,20 @@ def to_list(self): return self._l_ -class AttrDict: +class AttrDict(Generic[_KeyT, _ValT]): """ Helper class to provide attribute like access (read and write) to dictionaries. Used to provide a convenient way to access both results and nested dsl dicts. """ - def __init__(self, d): + _d_: Dict[_KeyT, _ValT] + + def __init__(self, d: Dict[_KeyT, _ValT]): # assign the inner dict manually to prevent __setattr__ from firing super().__setattr__("_d_", d) - def __contains__(self, key): + def __contains__(self, key: object) -> bool: return key in self._d_ def __nonzero__(self): diff --git a/elasticsearch_dsl/wrappers.py b/elasticsearch_dsl/wrappers.py index 0dbca982f..9009d89d0 100644 --- a/elasticsearch_dsl/wrappers.py +++ b/elasticsearch_dsl/wrappers.py @@ -16,26 +16,78 @@ # under the License. import operator +from typing import ( + Any, + Callable, + ClassVar, + Dict, + Literal, + Mapping, + Optional, + Protocol, + Tuple, + TypeVar, + Union, + cast, +) + +from typing_extensions import TypeAlias from .utils import AttrDict -__all__ = ["Range"] + +class SupportsDunderLT(Protocol): + def __lt__(self, other: Any, /) -> Any: ... + + +class SupportsDunderGT(Protocol): + def __gt__(self, other: Any, /) -> Any: ... + + +class SupportsDunderLE(Protocol): + def __le__(self, other: Any, /) -> Any: ... + + +class SupportsDunderGE(Protocol): + def __ge__(self, other: Any, /) -> Any: ... -class Range(AttrDict): - OPS = { +SupportsComparison: TypeAlias = Union[ + SupportsDunderLE, SupportsDunderGE, SupportsDunderGT, SupportsDunderLT +] + +ComparisonOperators: TypeAlias = Literal["lt", "lte", "gt", "gte"] +RangeValT = TypeVar("RangeValT", bound=SupportsComparison) + +__all__ = ["Range", "SupportsComparison"] + + +class Range(AttrDict[ComparisonOperators, RangeValT]): + OPS: ClassVar[ + Mapping[ + ComparisonOperators, + Callable[[SupportsComparison, SupportsComparison], bool], + ] + ] = { "lt": operator.lt, "lte": operator.le, "gt": operator.gt, "gte": operator.ge, } - def __init__(self, *args, **kwargs): - if args and (len(args) > 1 or kwargs or not isinstance(args[0], dict)): + def __init__( + self, + d: Optional[Dict[ComparisonOperators, RangeValT]] = None, + /, + **kwargs: RangeValT, + ): + if d is not None and (kwargs or not isinstance(d, dict)): raise ValueError( "Range accepts a single dictionary or a set of keyword arguments." ) - data = args[0] if args else kwargs + + # Cast here since mypy is inferring d as an `object` type for some reason + data = cast(Dict[str, RangeValT], d) if d is not None else kwargs for k in data: if k not in self.OPS: @@ -47,22 +99,32 @@ def __init__(self, *args, **kwargs): if "lt" in data and "lte" in data: raise ValueError("You cannot specify both lt and lte for Range.") - super().__init__(args[0] if args else kwargs) + # Here we use cast() since we now the keys are in the allowed values, but mypy does + # not infer it. + super().__init__(cast(Dict[ComparisonOperators, RangeValT], data)) - def __repr__(self): + def __repr__(self) -> str: return "Range(%s)" % ", ".join("%s=%r" % op for op in self._d_.items()) - def __contains__(self, item): + def __contains__(self, item: object) -> bool: if isinstance(item, str): return super().__contains__(item) + item_supports_comp = any(hasattr(item, f"__{op}__") for op in self.OPS) + if not item_supports_comp: + return False + + # Cast to tell mypy whe have checked it and its ok to use the comparison methods + # on `item` + item = cast(SupportsComparison, item) + for op in self.OPS: if op in self._d_ and not self.OPS[op](item, self._d_[op]): return False return True @property - def upper(self): + def upper(self) -> Union[Tuple[RangeValT, bool], Tuple[None, Literal[False]]]: if "lt" in self._d_: return self._d_["lt"], False if "lte" in self._d_: @@ -70,7 +132,7 @@ def upper(self): return None, False @property - def lower(self): + def lower(self) -> Union[Tuple[RangeValT, bool], Tuple[None, Literal[False]]]: if "gt" in self._d_: return self._d_["gt"], False if "gte" in self._d_: diff --git a/noxfile.py b/noxfile.py index f90f22f04..a731f99b9 100644 --- a/noxfile.py +++ b/noxfile.py @@ -32,7 +32,9 @@ TYPED_FILES = ( "elasticsearch_dsl/function.py", "elasticsearch_dsl/query.py", + "elasticsearch_dsl/wrappers.py", "tests/test_query.py", + "tests/test_wrappers.py", ) diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index 454722711..4c8c93f41 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -16,10 +16,12 @@ # under the License. from datetime import datetime, timedelta +from typing import Any, Mapping, Optional, Sequence import pytest from elasticsearch_dsl import Range +from elasticsearch_dsl.wrappers import SupportsComparison @pytest.mark.parametrize( @@ -34,7 +36,9 @@ ({"gt": datetime.now() - timedelta(seconds=10)}, datetime.now()), ], ) -def test_range_contains(kwargs, item): +def test_range_contains( + kwargs: Mapping[str, SupportsComparison], item: SupportsComparison +) -> None: assert item in Range(**kwargs) @@ -48,7 +52,9 @@ def test_range_contains(kwargs, item): ({"lte": datetime.now() - timedelta(seconds=10)}, datetime.now()), ], ) -def test_range_not_contains(kwargs, item): +def test_range_not_contains( + kwargs: Mapping[str, SupportsComparison], item: SupportsComparison +) -> None: assert item not in Range(**kwargs) @@ -62,7 +68,9 @@ def test_range_not_contains(kwargs, item): ((), {"gt": 1, "gte": 1}), ], ) -def test_range_raises_value_error_on_wrong_params(args, kwargs): +def test_range_raises_value_error_on_wrong_params( + args: Sequence[Any], kwargs: Mapping[str, SupportsComparison] +) -> None: with pytest.raises(ValueError): Range(*args, **kwargs) @@ -76,7 +84,11 @@ def test_range_raises_value_error_on_wrong_params(args, kwargs): (Range(lt=42), None, False), ], ) -def test_range_lower(range, lower, inclusive): +def test_range_lower( + range: Range[SupportsComparison], + lower: Optional[SupportsComparison], + inclusive: bool, +) -> None: assert (lower, inclusive) == range.lower @@ -89,5 +101,9 @@ def test_range_lower(range, lower, inclusive): (Range(gt=42), None, False), ], ) -def test_range_upper(range, upper, inclusive): +def test_range_upper( + range: Range[SupportsComparison], + upper: Optional[SupportsComparison], + inclusive: bool, +) -> None: assert (upper, inclusive) == range.upper From c407c2dd3d4716aae89a16b824fcb41ed32a24cb Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Mon, 10 Jun 2024 12:53:22 +0100 Subject: [PATCH 2/5] use _SupportsComparison type from typeshed --- elasticsearch_dsl/wrappers.py | 35 ++++++++--------------------------- tests/test_wrappers.py | 20 +++++++++++--------- 2 files changed, 19 insertions(+), 36 deletions(-) diff --git a/elasticsearch_dsl/wrappers.py b/elasticsearch_dsl/wrappers.py index 9009d89d0..83d85b98c 100644 --- a/elasticsearch_dsl/wrappers.py +++ b/elasticsearch_dsl/wrappers.py @@ -17,56 +17,37 @@ import operator from typing import ( - Any, + TYPE_CHECKING, Callable, ClassVar, Dict, Literal, Mapping, Optional, - Protocol, Tuple, TypeVar, Union, cast, ) +if TYPE_CHECKING: + from _operator import _SupportsComparison + from typing_extensions import TypeAlias from .utils import AttrDict - -class SupportsDunderLT(Protocol): - def __lt__(self, other: Any, /) -> Any: ... - - -class SupportsDunderGT(Protocol): - def __gt__(self, other: Any, /) -> Any: ... - - -class SupportsDunderLE(Protocol): - def __le__(self, other: Any, /) -> Any: ... - - -class SupportsDunderGE(Protocol): - def __ge__(self, other: Any, /) -> Any: ... - - -SupportsComparison: TypeAlias = Union[ - SupportsDunderLE, SupportsDunderGE, SupportsDunderGT, SupportsDunderLT -] - ComparisonOperators: TypeAlias = Literal["lt", "lte", "gt", "gte"] -RangeValT = TypeVar("RangeValT", bound=SupportsComparison) +RangeValT = TypeVar("RangeValT", bound=_SupportsComparison) -__all__ = ["Range", "SupportsComparison"] +__all__ = ["Range"] class Range(AttrDict[ComparisonOperators, RangeValT]): OPS: ClassVar[ Mapping[ ComparisonOperators, - Callable[[SupportsComparison, SupportsComparison], bool], + Callable[[_SupportsComparison, _SupportsComparison], bool], ] ] = { "lt": operator.lt, @@ -116,7 +97,7 @@ def __contains__(self, item: object) -> bool: # Cast to tell mypy whe have checked it and its ok to use the comparison methods # on `item` - item = cast(SupportsComparison, item) + item = cast(_SupportsComparison, item) for op in self.OPS: if op in self._d_ and not self.OPS[op](item, self._d_[op]): diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index 4c8c93f41..f41537d67 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -16,12 +16,14 @@ # under the License. from datetime import datetime, timedelta -from typing import Any, Mapping, Optional, Sequence +from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence + +if TYPE_CHECKING: + from _operator import _SupportsComparison import pytest from elasticsearch_dsl import Range -from elasticsearch_dsl.wrappers import SupportsComparison @pytest.mark.parametrize( @@ -37,7 +39,7 @@ ], ) def test_range_contains( - kwargs: Mapping[str, SupportsComparison], item: SupportsComparison + kwargs: Mapping[str, _SupportsComparison], item: _SupportsComparison ) -> None: assert item in Range(**kwargs) @@ -53,7 +55,7 @@ def test_range_contains( ], ) def test_range_not_contains( - kwargs: Mapping[str, SupportsComparison], item: SupportsComparison + kwargs: Mapping[str, _SupportsComparison], item: _SupportsComparison ) -> None: assert item not in Range(**kwargs) @@ -69,7 +71,7 @@ def test_range_not_contains( ], ) def test_range_raises_value_error_on_wrong_params( - args: Sequence[Any], kwargs: Mapping[str, SupportsComparison] + args: Sequence[Any], kwargs: Mapping[str, _SupportsComparison] ) -> None: with pytest.raises(ValueError): Range(*args, **kwargs) @@ -85,8 +87,8 @@ def test_range_raises_value_error_on_wrong_params( ], ) def test_range_lower( - range: Range[SupportsComparison], - lower: Optional[SupportsComparison], + range: Range[_SupportsComparison], + lower: Optional[_SupportsComparison], inclusive: bool, ) -> None: assert (lower, inclusive) == range.lower @@ -102,8 +104,8 @@ def test_range_lower( ], ) def test_range_upper( - range: Range[SupportsComparison], - upper: Optional[SupportsComparison], + range: Range[_SupportsComparison], + upper: Optional[_SupportsComparison], inclusive: bool, ) -> None: assert (upper, inclusive) == range.upper From d2164c0c98077d12f5c02a49fbd08af141a84e19 Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Mon, 10 Jun 2024 14:08:47 +0100 Subject: [PATCH 3/5] escape imported types in quotes --- elasticsearch_dsl/wrappers.py | 6 +++--- tests/test_wrappers.py | 14 +++++++------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/elasticsearch_dsl/wrappers.py b/elasticsearch_dsl/wrappers.py index 83d85b98c..99222cd6a 100644 --- a/elasticsearch_dsl/wrappers.py +++ b/elasticsearch_dsl/wrappers.py @@ -38,7 +38,7 @@ from .utils import AttrDict ComparisonOperators: TypeAlias = Literal["lt", "lte", "gt", "gte"] -RangeValT = TypeVar("RangeValT", bound=_SupportsComparison) +RangeValT = TypeVar("RangeValT", bound="_SupportsComparison") __all__ = ["Range"] @@ -47,7 +47,7 @@ class Range(AttrDict[ComparisonOperators, RangeValT]): OPS: ClassVar[ Mapping[ ComparisonOperators, - Callable[[_SupportsComparison, _SupportsComparison], bool], + Callable[["_SupportsComparison", "_SupportsComparison"], bool], ] ] = { "lt": operator.lt, @@ -97,7 +97,7 @@ def __contains__(self, item: object) -> bool: # Cast to tell mypy whe have checked it and its ok to use the comparison methods # on `item` - item = cast(_SupportsComparison, item) + item = cast("_SupportsComparison", item) for op in self.OPS: if op in self._d_ and not self.OPS[op](item, self._d_[op]): diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index f41537d67..f8acd5865 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -39,7 +39,7 @@ ], ) def test_range_contains( - kwargs: Mapping[str, _SupportsComparison], item: _SupportsComparison + kwargs: Mapping[str, "_SupportsComparison"], item: "_SupportsComparison" ) -> None: assert item in Range(**kwargs) @@ -55,7 +55,7 @@ def test_range_contains( ], ) def test_range_not_contains( - kwargs: Mapping[str, _SupportsComparison], item: _SupportsComparison + kwargs: Mapping[str, "_SupportsComparison"], item: "_SupportsComparison" ) -> None: assert item not in Range(**kwargs) @@ -71,7 +71,7 @@ def test_range_not_contains( ], ) def test_range_raises_value_error_on_wrong_params( - args: Sequence[Any], kwargs: Mapping[str, _SupportsComparison] + args: Sequence[Any], kwargs: Mapping[str, "_SupportsComparison"] ) -> None: with pytest.raises(ValueError): Range(*args, **kwargs) @@ -87,8 +87,8 @@ def test_range_raises_value_error_on_wrong_params( ], ) def test_range_lower( - range: Range[_SupportsComparison], - lower: Optional[_SupportsComparison], + range: Range["_SupportsComparison"], + lower: Optional["_SupportsComparison"], inclusive: bool, ) -> None: assert (lower, inclusive) == range.lower @@ -104,8 +104,8 @@ def test_range_lower( ], ) def test_range_upper( - range: Range[_SupportsComparison], - upper: Optional[_SupportsComparison], + range: Range["_SupportsComparison"], + upper: Optional["_SupportsComparison"], inclusive: bool, ) -> None: assert (upper, inclusive) == range.upper From 8f99e3d7c632c9bc894609643b24c21c7217c654 Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Wed, 12 Jun 2024 13:59:28 +0100 Subject: [PATCH 4/5] simplify casts --- elasticsearch_dsl/wrappers.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/elasticsearch_dsl/wrappers.py b/elasticsearch_dsl/wrappers.py index 99222cd6a..e759fdc07 100644 --- a/elasticsearch_dsl/wrappers.py +++ b/elasticsearch_dsl/wrappers.py @@ -67,8 +67,10 @@ def __init__( "Range accepts a single dictionary or a set of keyword arguments." ) - # Cast here since mypy is inferring d as an `object` type for some reason - data = cast(Dict[str, RangeValT], d) if d is not None else kwargs + if d is None: + data = cast(Dict[ComparisonOperators, RangeValT], kwargs) + else: + data = d for k in data: if k not in self.OPS: @@ -80,9 +82,7 @@ def __init__( if "lt" in data and "lte" in data: raise ValueError("You cannot specify both lt and lte for Range.") - # Here we use cast() since we now the keys are in the allowed values, but mypy does - # not infer it. - super().__init__(cast(Dict[ComparisonOperators, RangeValT], data)) + super().__init__(data) def __repr__(self) -> str: return "Range(%s)" % ", ".join("%s=%r" % op for op in self._d_.items()) @@ -95,12 +95,8 @@ def __contains__(self, item: object) -> bool: if not item_supports_comp: return False - # Cast to tell mypy whe have checked it and its ok to use the comparison methods - # on `item` - item = cast("_SupportsComparison", item) - for op in self.OPS: - if op in self._d_ and not self.OPS[op](item, self._d_[op]): + if op in self._d_ and not self.OPS[op](cast("_SupportsComparison", item), self._d_[op]): return False return True From 20209b08b98a7dcd52417ae696fabbcfc5deef06 Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Wed, 12 Jun 2024 14:13:42 +0100 Subject: [PATCH 5/5] fixed linter errors --- elasticsearch_dsl/wrappers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/elasticsearch_dsl/wrappers.py b/elasticsearch_dsl/wrappers.py index e759fdc07..4b6a9db5e 100644 --- a/elasticsearch_dsl/wrappers.py +++ b/elasticsearch_dsl/wrappers.py @@ -96,7 +96,9 @@ def __contains__(self, item: object) -> bool: return False for op in self.OPS: - if op in self._d_ and not self.OPS[op](cast("_SupportsComparison", item), self._d_[op]): + if op in self._d_ and not self.OPS[op]( + cast("_SupportsComparison", item), self._d_[op] + ): return False return True