Skip to content

Commit cd9a499

Browse files
Dr-Irvtwoertwein
authored andcommitted
Allow sorted() to work on Index (pandas-dev#501)
* trying SupportsRichComparisonT * create IndexIterScalar type for sorted(Index) to work * Use List not list, and Tuple not tuple for 3.8 compat
1 parent 6634608 commit cd9a499

File tree

3 files changed

+35
-4
lines changed

3 files changed

+35
-4
lines changed

pandas-stubs/_typing.pyi

+5-2
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ ListLikeExceptSeriesAndStr = TypeVar(
165165
)
166166
ListLikeU: TypeAlias = Union[Sequence, np.ndarray, Series, Index]
167167
StrLike: TypeAlias = Union[str, np.str_]
168-
Scalar: TypeAlias = Union[
168+
IndexIterScalar: TypeAlias = Union[
169169
str,
170170
bytes,
171171
datetime.date,
@@ -174,10 +174,13 @@ Scalar: TypeAlias = Union[
174174
bool,
175175
int,
176176
float,
177-
complex,
178177
Timestamp,
179178
Timedelta,
180179
]
180+
Scalar: TypeAlias = Union[
181+
IndexIterScalar,
182+
complex,
183+
]
181184
ScalarT = TypeVar("ScalarT", bound=Scalar)
182185
# Refine the definitions below in 3.9 to use the specialized type.
183186
np_ndarray_int64: TypeAlias = npt.NDArray[np.int64]

pandas-stubs/core/indexes/base.pyi

+2-1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ from pandas._typing import (
3232
DtypeObj,
3333
FillnaOptions,
3434
HashableT,
35+
IndexIterScalar,
3536
IndexT,
3637
Label,
3738
Level,
@@ -222,7 +223,7 @@ class Index(IndexOpsMixin, PandasObject):
222223
def shape(self) -> tuple[int, ...]: ...
223224
# Extra methods from old stubs
224225
def __eq__(self, other: object) -> np_ndarray_bool: ... # type: ignore[override]
225-
def __iter__(self) -> Iterator[Scalar | tuple[Hashable, ...]]: ...
226+
def __iter__(self) -> Iterator[IndexIterScalar | tuple[Hashable, ...]]: ...
226227
def __ne__(self, other: object) -> np_ndarray_bool: ... # type: ignore[override]
227228
def __le__(self, other: Index | Scalar) -> np_ndarray_bool: ...
228229
def __ge__(self, other: Index | Scalar) -> np_ndarray_bool: ...

tests/test_indexes.py

+28-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import datetime as dt
44
from typing import (
5+
TYPE_CHECKING,
56
Hashable,
67
List,
78
Tuple,
@@ -16,6 +17,9 @@
1617

1718
from pandas._typing import Scalar
1819

20+
if TYPE_CHECKING:
21+
from pandas._typing import IndexIterScalar
22+
1923
from tests import (
2024
check,
2125
pytest_warns_bounded,
@@ -92,7 +96,11 @@ def test_column_contains() -> None:
9296
def test_column_sequence() -> None:
9397
df = pd.DataFrame([1, 2, 3])
9498
col_list = list(df.columns)
95-
check(assert_type(col_list, List[Union[Scalar, Tuple[Hashable, ...]]]), list, int)
99+
check(
100+
assert_type(col_list, List[Union["IndexIterScalar", Tuple[Hashable, ...]]]),
101+
list,
102+
int,
103+
)
96104

97105

98106
def test_difference_none() -> None:
@@ -657,3 +665,22 @@ def test_interval_index_tuples():
657665
pd.IntervalIndex,
658666
pd.Interval,
659667
)
668+
669+
670+
def test_sorted_and_list() -> None:
671+
# GH 497
672+
i1 = pd.Index([3, 2, 1])
673+
check(
674+
assert_type(
675+
sorted(i1),
676+
List[Union["IndexIterScalar", Tuple[Hashable, ...]]],
677+
),
678+
list,
679+
)
680+
check(
681+
assert_type(
682+
list(i1),
683+
List[Union["IndexIterScalar", Tuple[Hashable, ...]]],
684+
),
685+
list,
686+
)

0 commit comments

Comments
 (0)