Skip to content

Commit c407c2d

Browse files
use _SupportsComparison type from typeshed
1 parent 3668478 commit c407c2d

File tree

2 files changed

+19
-36
lines changed

2 files changed

+19
-36
lines changed

elasticsearch_dsl/wrappers.py

Lines changed: 8 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -17,56 +17,37 @@
1717

1818
import operator
1919
from typing import (
20-
Any,
20+
TYPE_CHECKING,
2121
Callable,
2222
ClassVar,
2323
Dict,
2424
Literal,
2525
Mapping,
2626
Optional,
27-
Protocol,
2827
Tuple,
2928
TypeVar,
3029
Union,
3130
cast,
3231
)
3332

33+
if TYPE_CHECKING:
34+
from _operator import _SupportsComparison
35+
3436
from typing_extensions import TypeAlias
3537

3638
from .utils import AttrDict
3739

38-
39-
class SupportsDunderLT(Protocol):
40-
def __lt__(self, other: Any, /) -> Any: ...
41-
42-
43-
class SupportsDunderGT(Protocol):
44-
def __gt__(self, other: Any, /) -> Any: ...
45-
46-
47-
class SupportsDunderLE(Protocol):
48-
def __le__(self, other: Any, /) -> Any: ...
49-
50-
51-
class SupportsDunderGE(Protocol):
52-
def __ge__(self, other: Any, /) -> Any: ...
53-
54-
55-
SupportsComparison: TypeAlias = Union[
56-
SupportsDunderLE, SupportsDunderGE, SupportsDunderGT, SupportsDunderLT
57-
]
58-
5940
ComparisonOperators: TypeAlias = Literal["lt", "lte", "gt", "gte"]
60-
RangeValT = TypeVar("RangeValT", bound=SupportsComparison)
41+
RangeValT = TypeVar("RangeValT", bound=_SupportsComparison)
6142

62-
__all__ = ["Range", "SupportsComparison"]
43+
__all__ = ["Range"]
6344

6445

6546
class Range(AttrDict[ComparisonOperators, RangeValT]):
6647
OPS: ClassVar[
6748
Mapping[
6849
ComparisonOperators,
69-
Callable[[SupportsComparison, SupportsComparison], bool],
50+
Callable[[_SupportsComparison, _SupportsComparison], bool],
7051
]
7152
] = {
7253
"lt": operator.lt,
@@ -116,7 +97,7 @@ def __contains__(self, item: object) -> bool:
11697

11798
# Cast to tell mypy whe have checked it and its ok to use the comparison methods
11899
# on `item`
119-
item = cast(SupportsComparison, item)
100+
item = cast(_SupportsComparison, item)
120101

121102
for op in self.OPS:
122103
if op in self._d_ and not self.OPS[op](item, self._d_[op]):

tests/test_wrappers.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@
1616
# under the License.
1717

1818
from datetime import datetime, timedelta
19-
from typing import Any, Mapping, Optional, Sequence
19+
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence
20+
21+
if TYPE_CHECKING:
22+
from _operator import _SupportsComparison
2023

2124
import pytest
2225

2326
from elasticsearch_dsl import Range
24-
from elasticsearch_dsl.wrappers import SupportsComparison
2527

2628

2729
@pytest.mark.parametrize(
@@ -37,7 +39,7 @@
3739
],
3840
)
3941
def test_range_contains(
40-
kwargs: Mapping[str, SupportsComparison], item: SupportsComparison
42+
kwargs: Mapping[str, _SupportsComparison], item: _SupportsComparison
4143
) -> None:
4244
assert item in Range(**kwargs)
4345

@@ -53,7 +55,7 @@ def test_range_contains(
5355
],
5456
)
5557
def test_range_not_contains(
56-
kwargs: Mapping[str, SupportsComparison], item: SupportsComparison
58+
kwargs: Mapping[str, _SupportsComparison], item: _SupportsComparison
5759
) -> None:
5860
assert item not in Range(**kwargs)
5961

@@ -69,7 +71,7 @@ def test_range_not_contains(
6971
],
7072
)
7173
def test_range_raises_value_error_on_wrong_params(
72-
args: Sequence[Any], kwargs: Mapping[str, SupportsComparison]
74+
args: Sequence[Any], kwargs: Mapping[str, _SupportsComparison]
7375
) -> None:
7476
with pytest.raises(ValueError):
7577
Range(*args, **kwargs)
@@ -85,8 +87,8 @@ def test_range_raises_value_error_on_wrong_params(
8587
],
8688
)
8789
def test_range_lower(
88-
range: Range[SupportsComparison],
89-
lower: Optional[SupportsComparison],
90+
range: Range[_SupportsComparison],
91+
lower: Optional[_SupportsComparison],
9092
inclusive: bool,
9193
) -> None:
9294
assert (lower, inclusive) == range.lower
@@ -102,8 +104,8 @@ def test_range_lower(
102104
],
103105
)
104106
def test_range_upper(
105-
range: Range[SupportsComparison],
106-
upper: Optional[SupportsComparison],
107+
range: Range[_SupportsComparison],
108+
upper: Optional[_SupportsComparison],
107109
inclusive: bool,
108110
) -> None:
109111
assert (upper, inclusive) == range.upper

0 commit comments

Comments
 (0)