Skip to content

Commit 90f43ca

Browse files
committed
refactor: add type hints to wrappers.py
1 parent 76a57fd commit 90f43ca

File tree

4 files changed

+116
-22
lines changed

4 files changed

+116
-22
lines changed

elasticsearch_dsl/utils.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,25 @@
1818

1919
import collections.abc
2020
from copy import copy
21-
from typing import Any, ClassVar, Dict, List, Optional, Type, Union
21+
from typing import Any, ClassVar, Dict, Generic, List, Optional, Type, TypeVar, Union
2222

23-
from typing_extensions import Self
23+
from typing_extensions import Self, TypeAlias
2424

2525
from .exceptions import UnknownDslObject, ValidationException
2626

27-
JSONType = Union[int, bool, str, float, List["JSONType"], Dict[str, "JSONType"]]
27+
# Usefull types
28+
29+
JSONType: TypeAlias = Union[
30+
int, bool, str, float, List["JSONType"], Dict[str, "JSONType"]
31+
]
32+
33+
34+
# Type variables for internals
35+
36+
_KeyT = TypeVar("_KeyT")
37+
_ValT = TypeVar("_ValT")
38+
39+
# Constants
2840

2941
SKIP_VALUES = ("", None)
3042
EXPAND__TO_DOT = True
@@ -110,18 +122,20 @@ def to_list(self):
110122
return self._l_
111123

112124

113-
class AttrDict:
125+
class AttrDict(Generic[_KeyT, _ValT]):
114126
"""
115127
Helper class to provide attribute like access (read and write) to
116128
dictionaries. Used to provide a convenient way to access both results and
117129
nested dsl dicts.
118130
"""
119131

120-
def __init__(self, d):
132+
_d_: Dict[_KeyT, _ValT]
133+
134+
def __init__(self, d: Dict[_KeyT, _ValT]):
121135
# assign the inner dict manually to prevent __setattr__ from firing
122136
super().__setattr__("_d_", d)
123137

124-
def __contains__(self, key):
138+
def __contains__(self, key: object) -> bool:
125139
return key in self._d_
126140

127141
def __nonzero__(self):

elasticsearch_dsl/wrappers.py

Lines changed: 73 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,26 +16,78 @@
1616
# under the License.
1717

1818
import operator
19+
from typing import (
20+
Any,
21+
Callable,
22+
ClassVar,
23+
Dict,
24+
Literal,
25+
Mapping,
26+
Optional,
27+
Protocol,
28+
Tuple,
29+
TypeVar,
30+
Union,
31+
cast,
32+
)
33+
34+
from typing_extensions import TypeAlias
1935

2036
from .utils import AttrDict
2137

22-
__all__ = ["Range"]
38+
39+
class SupportsDunderLT(Protocol):
40+
def __lt__(self, other: Any, /) -> Any: ...
41+
42+
43+
class SupportsDunderGT(Protocol):
44+
def __gt__(self, other: Any, /) -> Any: ...
45+
46+
47+
class SupportsDunderLE(Protocol):
48+
def __le__(self, other: Any, /) -> Any: ...
49+
50+
51+
class SupportsDunderGE(Protocol):
52+
def __ge__(self, other: Any, /) -> Any: ...
2353

2454

25-
class Range(AttrDict):
26-
OPS = {
55+
SupportsComparison: TypeAlias = Union[
56+
SupportsDunderLE, SupportsDunderGE, SupportsDunderGT, SupportsDunderLT
57+
]
58+
59+
ComparisonOperators: TypeAlias = Literal["lt", "lte", "gt", "gte"]
60+
RangeValT = TypeVar("RangeValT", bound=SupportsComparison)
61+
62+
__all__ = ["Range", "SupportsComparison"]
63+
64+
65+
class Range(AttrDict[ComparisonOperators, RangeValT]):
66+
OPS: ClassVar[
67+
Mapping[
68+
ComparisonOperators,
69+
Callable[[SupportsComparison, SupportsComparison], bool],
70+
]
71+
] = {
2772
"lt": operator.lt,
2873
"lte": operator.le,
2974
"gt": operator.gt,
3075
"gte": operator.ge,
3176
}
3277

33-
def __init__(self, *args, **kwargs):
34-
if args and (len(args) > 1 or kwargs or not isinstance(args[0], dict)):
78+
def __init__(
79+
self,
80+
d: Optional[Dict[ComparisonOperators, RangeValT]] = None,
81+
/,
82+
**kwargs: RangeValT,
83+
):
84+
if d is not None and (kwargs or not isinstance(d, dict)):
3585
raise ValueError(
3686
"Range accepts a single dictionary or a set of keyword arguments."
3787
)
38-
data = args[0] if args else kwargs
88+
89+
# Cast here since mypy is inferring d as an `object` type for some reason
90+
data = cast(Dict[str, RangeValT], d) if d is not None else kwargs
3991

4092
for k in data:
4193
if k not in self.OPS:
@@ -47,30 +99,40 @@ def __init__(self, *args, **kwargs):
4799
if "lt" in data and "lte" in data:
48100
raise ValueError("You cannot specify both lt and lte for Range.")
49101

50-
super().__init__(args[0] if args else kwargs)
102+
# Here we use cast() since we now the keys are in the allowed values, but mypy does
103+
# not infer it.
104+
super().__init__(cast(Dict[ComparisonOperators, RangeValT], data))
51105

52-
def __repr__(self):
106+
def __repr__(self) -> str:
53107
return "Range(%s)" % ", ".join("%s=%r" % op for op in self._d_.items())
54108

55-
def __contains__(self, item):
109+
def __contains__(self, item: object) -> bool:
56110
if isinstance(item, str):
57111
return super().__contains__(item)
58112

113+
item_supports_comp = any(hasattr(item, f"__{op}__") for op in self.OPS)
114+
if not item_supports_comp:
115+
return False
116+
117+
# Cast to tell mypy whe have checked it and its ok to use the comparison methods
118+
# on `item`
119+
item = cast(SupportsComparison, item)
120+
59121
for op in self.OPS:
60122
if op in self._d_ and not self.OPS[op](item, self._d_[op]):
61123
return False
62124
return True
63125

64126
@property
65-
def upper(self):
127+
def upper(self) -> Union[Tuple[RangeValT, bool], Tuple[None, Literal[False]]]:
66128
if "lt" in self._d_:
67129
return self._d_["lt"], False
68130
if "lte" in self._d_:
69131
return self._d_["lte"], True
70132
return None, False
71133

72134
@property
73-
def lower(self):
135+
def lower(self) -> Union[Tuple[RangeValT, bool], Tuple[None, Literal[False]]]:
74136
if "gt" in self._d_:
75137
return self._d_["gt"], False
76138
if "gte" in self._d_:

noxfile.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@
3232
TYPED_FILES = (
3333
"elasticsearch_dsl/function.py",
3434
"elasticsearch_dsl/query.py",
35+
"elasticsearch_dsl/wrappers.py",
3536
"tests/test_query.py",
37+
"tests/test_wrappers.py",
3638
)
3739

3840

tests/test_wrappers.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@
1616
# under the License.
1717

1818
from datetime import datetime, timedelta
19+
from typing import Any, Mapping, Optional, Sequence
1920

2021
import pytest
2122

2223
from elasticsearch_dsl import Range
24+
from elasticsearch_dsl.wrappers import SupportsComparison
2325

2426

2527
@pytest.mark.parametrize(
@@ -34,7 +36,9 @@
3436
({"gt": datetime.now() - timedelta(seconds=10)}, datetime.now()),
3537
],
3638
)
37-
def test_range_contains(kwargs, item):
39+
def test_range_contains(
40+
kwargs: Mapping[str, SupportsComparison], item: SupportsComparison
41+
) -> None:
3842
assert item in Range(**kwargs)
3943

4044

@@ -48,7 +52,9 @@ def test_range_contains(kwargs, item):
4852
({"lte": datetime.now() - timedelta(seconds=10)}, datetime.now()),
4953
],
5054
)
51-
def test_range_not_contains(kwargs, item):
55+
def test_range_not_contains(
56+
kwargs: Mapping[str, SupportsComparison], item: SupportsComparison
57+
) -> None:
5258
assert item not in Range(**kwargs)
5359

5460

@@ -62,7 +68,9 @@ def test_range_not_contains(kwargs, item):
6268
((), {"gt": 1, "gte": 1}),
6369
],
6470
)
65-
def test_range_raises_value_error_on_wrong_params(args, kwargs):
71+
def test_range_raises_value_error_on_wrong_params(
72+
args: Sequence[Any], kwargs: Mapping[str, SupportsComparison]
73+
) -> None:
6674
with pytest.raises(ValueError):
6775
Range(*args, **kwargs)
6876

@@ -76,7 +84,11 @@ def test_range_raises_value_error_on_wrong_params(args, kwargs):
7684
(Range(lt=42), None, False),
7785
],
7886
)
79-
def test_range_lower(range, lower, inclusive):
87+
def test_range_lower(
88+
range: Range[SupportsComparison],
89+
lower: Optional[SupportsComparison],
90+
inclusive: bool,
91+
) -> None:
8092
assert (lower, inclusive) == range.lower
8193

8294

@@ -89,5 +101,9 @@ def test_range_lower(range, lower, inclusive):
89101
(Range(gt=42), None, False),
90102
],
91103
)
92-
def test_range_upper(range, upper, inclusive):
104+
def test_range_upper(
105+
range: Range[SupportsComparison],
106+
upper: Optional[SupportsComparison],
107+
inclusive: bool,
108+
) -> None:
93109
assert (upper, inclusive) == range.upper

0 commit comments

Comments
 (0)