Skip to content

Commit 2c79b48

Browse files
Add type hints to wrappers.py (#1835)
* refactor: add type hints to wrappers.py * use _SupportsComparison type from typeshed * escape imported types in quotes * simplify casts * fixed linter errors --------- Co-authored-by: Miguel Grinberg <[email protected]>
1 parent 412708e commit 2c79b48

File tree

4 files changed

+97
-22
lines changed

4 files changed

+97
-22
lines changed

elasticsearch_dsl/utils.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,25 @@
1818

1919
import collections.abc
2020
from copy import copy
21-
from typing import Any, ClassVar, Dict, List, Optional, Type, Union
21+
from typing import Any, ClassVar, Dict, Generic, List, Optional, Type, TypeVar, Union
2222

23-
from typing_extensions import Self
23+
from typing_extensions import Self, TypeAlias
2424

2525
from .exceptions import UnknownDslObject, ValidationException
2626

27-
JSONType = Union[int, bool, str, float, List["JSONType"], Dict[str, "JSONType"]]
27+
# Usefull types
28+
29+
JSONType: TypeAlias = Union[
30+
int, bool, str, float, List["JSONType"], Dict[str, "JSONType"]
31+
]
32+
33+
34+
# Type variables for internals
35+
36+
_KeyT = TypeVar("_KeyT")
37+
_ValT = TypeVar("_ValT")
38+
39+
# Constants
2840

2941
SKIP_VALUES = ("", None)
3042
EXPAND__TO_DOT = True
@@ -110,18 +122,20 @@ def to_list(self):
110122
return self._l_
111123

112124

113-
class AttrDict:
125+
class AttrDict(Generic[_KeyT, _ValT]):
114126
"""
115127
Helper class to provide attribute like access (read and write) to
116128
dictionaries. Used to provide a convenient way to access both results and
117129
nested dsl dicts.
118130
"""
119131

120-
def __init__(self, d):
132+
_d_: Dict[_KeyT, _ValT]
133+
134+
def __init__(self, d: Dict[_KeyT, _ValT]):
121135
# assign the inner dict manually to prevent __setattr__ from firing
122136
super().__setattr__("_d_", d)
123137

124-
def __contains__(self, key):
138+
def __contains__(self, key: object) -> bool:
125139
return key in self._d_
126140

127141
def __nonzero__(self):

elasticsearch_dsl/wrappers.py

Lines changed: 52 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,26 +16,61 @@
1616
# under the License.
1717

1818
import operator
19+
from typing import (
20+
TYPE_CHECKING,
21+
Callable,
22+
ClassVar,
23+
Dict,
24+
Literal,
25+
Mapping,
26+
Optional,
27+
Tuple,
28+
TypeVar,
29+
Union,
30+
cast,
31+
)
32+
33+
if TYPE_CHECKING:
34+
from _operator import _SupportsComparison
35+
36+
from typing_extensions import TypeAlias
1937

2038
from .utils import AttrDict
2139

40+
ComparisonOperators: TypeAlias = Literal["lt", "lte", "gt", "gte"]
41+
RangeValT = TypeVar("RangeValT", bound="_SupportsComparison")
42+
2243
__all__ = ["Range"]
2344

2445

25-
class Range(AttrDict):
26-
OPS = {
46+
class Range(AttrDict[ComparisonOperators, RangeValT]):
47+
OPS: ClassVar[
48+
Mapping[
49+
ComparisonOperators,
50+
Callable[["_SupportsComparison", "_SupportsComparison"], bool],
51+
]
52+
] = {
2753
"lt": operator.lt,
2854
"lte": operator.le,
2955
"gt": operator.gt,
3056
"gte": operator.ge,
3157
}
3258

33-
def __init__(self, *args, **kwargs):
34-
if args and (len(args) > 1 or kwargs or not isinstance(args[0], dict)):
59+
def __init__(
60+
self,
61+
d: Optional[Dict[ComparisonOperators, RangeValT]] = None,
62+
/,
63+
**kwargs: RangeValT,
64+
):
65+
if d is not None and (kwargs or not isinstance(d, dict)):
3566
raise ValueError(
3667
"Range accepts a single dictionary or a set of keyword arguments."
3768
)
38-
data = args[0] if args else kwargs
69+
70+
if d is None:
71+
data = cast(Dict[ComparisonOperators, RangeValT], kwargs)
72+
else:
73+
data = d
3974

4075
for k in data:
4176
if k not in self.OPS:
@@ -47,30 +82,36 @@ def __init__(self, *args, **kwargs):
4782
if "lt" in data and "lte" in data:
4883
raise ValueError("You cannot specify both lt and lte for Range.")
4984

50-
super().__init__(args[0] if args else kwargs)
85+
super().__init__(data)
5186

52-
def __repr__(self):
87+
def __repr__(self) -> str:
5388
return "Range(%s)" % ", ".join("%s=%r" % op for op in self._d_.items())
5489

55-
def __contains__(self, item):
90+
def __contains__(self, item: object) -> bool:
5691
if isinstance(item, str):
5792
return super().__contains__(item)
5893

94+
item_supports_comp = any(hasattr(item, f"__{op}__") for op in self.OPS)
95+
if not item_supports_comp:
96+
return False
97+
5998
for op in self.OPS:
60-
if op in self._d_ and not self.OPS[op](item, self._d_[op]):
99+
if op in self._d_ and not self.OPS[op](
100+
cast("_SupportsComparison", item), self._d_[op]
101+
):
61102
return False
62103
return True
63104

64105
@property
65-
def upper(self):
106+
def upper(self) -> Union[Tuple[RangeValT, bool], Tuple[None, Literal[False]]]:
66107
if "lt" in self._d_:
67108
return self._d_["lt"], False
68109
if "lte" in self._d_:
69110
return self._d_["lte"], True
70111
return None, False
71112

72113
@property
73-
def lower(self):
114+
def lower(self) -> Union[Tuple[RangeValT, bool], Tuple[None, Literal[False]]]:
74115
if "gt" in self._d_:
75116
return self._d_["gt"], False
76117
if "gte" in self._d_:

noxfile.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@
3232
TYPED_FILES = (
3333
"elasticsearch_dsl/function.py",
3434
"elasticsearch_dsl/query.py",
35+
"elasticsearch_dsl/wrappers.py",
3536
"tests/test_query.py",
37+
"tests/test_wrappers.py",
3638
)
3739

3840

tests/test_wrappers.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@
1616
# under the License.
1717

1818
from datetime import datetime, timedelta
19+
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence
20+
21+
if TYPE_CHECKING:
22+
from _operator import _SupportsComparison
1923

2024
import pytest
2125

@@ -34,7 +38,9 @@
3438
({"gt": datetime.now() - timedelta(seconds=10)}, datetime.now()),
3539
],
3640
)
37-
def test_range_contains(kwargs, item):
41+
def test_range_contains(
42+
kwargs: Mapping[str, "_SupportsComparison"], item: "_SupportsComparison"
43+
) -> None:
3844
assert item in Range(**kwargs)
3945

4046

@@ -48,7 +54,9 @@ def test_range_contains(kwargs, item):
4854
({"lte": datetime.now() - timedelta(seconds=10)}, datetime.now()),
4955
],
5056
)
51-
def test_range_not_contains(kwargs, item):
57+
def test_range_not_contains(
58+
kwargs: Mapping[str, "_SupportsComparison"], item: "_SupportsComparison"
59+
) -> None:
5260
assert item not in Range(**kwargs)
5361

5462

@@ -62,7 +70,9 @@ def test_range_not_contains(kwargs, item):
6270
((), {"gt": 1, "gte": 1}),
6371
],
6472
)
65-
def test_range_raises_value_error_on_wrong_params(args, kwargs):
73+
def test_range_raises_value_error_on_wrong_params(
74+
args: Sequence[Any], kwargs: Mapping[str, "_SupportsComparison"]
75+
) -> None:
6676
with pytest.raises(ValueError):
6777
Range(*args, **kwargs)
6878

@@ -76,7 +86,11 @@ def test_range_raises_value_error_on_wrong_params(args, kwargs):
7686
(Range(lt=42), None, False),
7787
],
7888
)
79-
def test_range_lower(range, lower, inclusive):
89+
def test_range_lower(
90+
range: Range["_SupportsComparison"],
91+
lower: Optional["_SupportsComparison"],
92+
inclusive: bool,
93+
) -> None:
8094
assert (lower, inclusive) == range.lower
8195

8296

@@ -89,5 +103,9 @@ def test_range_lower(range, lower, inclusive):
89103
(Range(gt=42), None, False),
90104
],
91105
)
92-
def test_range_upper(range, upper, inclusive):
106+
def test_range_upper(
107+
range: Range["_SupportsComparison"],
108+
upper: Optional["_SupportsComparison"],
109+
inclusive: bool,
110+
) -> None:
93111
assert (upper, inclusive) == range.upper

0 commit comments

Comments
 (0)