Skip to content

Commit ed85bf8

Browse files
Dr-Irvtwoertwein
authored andcommitted
Fix __getitem__() for various index types (pandas-dev#537)
* Fix __getitem__() for various index types * create TypeVar for the mixin
1 parent bdcc742 commit ed85bf8

11 files changed

+229
-68
lines changed

pandas-stubs/core/algorithms.pyi

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def unique(values: PeriodIndex) -> PeriodIndex: ... # type: ignore[misc] # pyri
2525
@overload
2626
def unique(values: CategoricalIndex) -> CategoricalIndex: ... # type: ignore[misc]
2727
@overload
28-
def unique(values: IntervalIndex[IntervalT]) -> IntervalIndex[IntervalT]: ... # type: ignore[misc]
28+
def unique(values: IntervalIndex[IntervalT]) -> IntervalIndex[IntervalT]: ...
2929
@overload
3030
def unique(values: Index) -> np.ndarray: ...
3131
@overload

pandas-stubs/core/indexes/base.pyi

+28-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ from collections.abc import (
77
)
88
from typing import (
99
ClassVar,
10+
Generic,
1011
Literal,
12+
TypeVar,
1113
overload,
1214
)
1315

@@ -27,6 +29,7 @@ from pandas.core.strings import StringMethods
2729
from typing_extensions import Never
2830

2931
from pandas._typing import (
32+
S1,
3033
T1,
3134
Dtype,
3235
DtypeArg,
@@ -48,6 +51,23 @@ class InvalidIndexError(Exception): ...
4851

4952
_str = str
5053

54+
_IndexGetitemMixinT = TypeVar("_IndexGetitemMixinT", bound=_IndexGetitemMixin)
55+
56+
class _IndexGetitemMixin(Generic[S1]):
57+
@overload
58+
def __getitem__(
59+
self: _IndexGetitemMixinT,
60+
idx: slice
61+
| np_ndarray_anyint
62+
| Sequence[int]
63+
| Index
64+
| Series[bool]
65+
| Sequence[bool]
66+
| np_ndarray_bool,
67+
) -> _IndexGetitemMixinT: ...
68+
@overload
69+
def __getitem__(self, idx: int) -> S1: ...
70+
5171
class Index(IndexOpsMixin, PandasObject):
5272
__hash__: ClassVar[None] # type: ignore[assignment]
5373
@overload
@@ -158,7 +178,7 @@ class Index(IndexOpsMixin, PandasObject):
158178
__bool__ = ...
159179
def union(self, other: list[HashableT] | Index, sort=...) -> Index: ...
160180
def intersection(self, other: list[T1] | Index, sort: bool = ...) -> Index: ...
161-
def difference(self, other: list | Index) -> Index: ...
181+
def difference(self, other: list | Index, sort: bool | None = None) -> Index: ...
162182
def symmetric_difference(
163183
self, other: list[T1] | Index, result_name=..., sort=...
164184
) -> Index: ...
@@ -191,7 +211,13 @@ class Index(IndexOpsMixin, PandasObject):
191211
@overload
192212
def __getitem__(
193213
self: IndexT,
194-
idx: slice | np_ndarray_anyint | Index | Series[bool] | np_ndarray_bool,
214+
idx: slice
215+
| np_ndarray_anyint
216+
| Sequence[int]
217+
| Index
218+
| Series[bool]
219+
| Sequence[bool]
220+
| np_ndarray_bool,
195221
) -> IndexT: ...
196222
@overload
197223
def __getitem__(self, idx: int | tuple[np_ndarray_anyint, ...]) -> Scalar: ...
+3-21
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pandas.core.indexes.extension import ExtensionIndex
2-
from pandas.core.indexes.numeric import Int64Index
2+
from pandas.core.indexes.timedeltas import TimedeltaIndex
33

44
from pandas._libs.tslibs import BaseOffset
55

@@ -10,28 +10,10 @@ class DatetimeIndexOpsMixin(ExtensionIndex):
1010
def freqstr(self) -> str | None: ...
1111
@property
1212
def is_all_dates(self) -> bool: ...
13-
@property
14-
def values(self): ...
15-
def __array_wrap__(self, result, context=...): ...
16-
def equals(self, other) -> bool: ...
17-
def __contains__(self, key): ...
18-
def sort_values(self, return_indexer: bool = ..., ascending: bool = ...): ...
19-
def take(
20-
self, indices, axis: int = ..., allow_fill: bool = ..., fill_value=..., **kwargs
21-
): ...
22-
def tolist(self) -> list: ...
2313
def min(self, axis=..., skipna: bool = ..., *args, **kwargs): ...
2414
def argmin(self, axis=..., skipna: bool = ..., *args, **kwargs): ...
2515
def max(self, axis=..., skipna: bool = ..., *args, **kwargs): ...
2616
def argmax(self, axis=..., skipna: bool = ..., *args, **kwargs): ...
27-
def isin(self, values, level=...): ...
28-
def where(self, cond, other=...): ...
29-
def shift(self, periods: int = ..., freq=...): ...
30-
def delete(self, loc): ...
17+
def __rsub__(self, other: DatetimeIndexOpsMixin) -> TimedeltaIndex: ...
3118

32-
class DatetimeTimedeltaMixin(DatetimeIndexOpsMixin, Int64Index):
33-
def difference(self, other, sort=...): ...
34-
def intersection(self, other, sort: bool = ...): ...
35-
def join(
36-
self, other, *, how: str = ..., level=..., return_indexers=..., sort=...
37-
): ...
19+
class DatetimeTimedeltaMixin(DatetimeIndexOpsMixin): ...

pandas-stubs/core/indexes/datetimes.pyi

+9-3
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ from pandas import (
1717
)
1818
from pandas.core.indexes.accessors import DatetimeIndexProperties
1919
from pandas.core.indexes.api import Float64Index
20+
from pandas.core.indexes.base import _IndexGetitemMixin
2021
from pandas.core.indexes.datetimelike import DatetimeTimedeltaMixin
2122
from pandas.core.series import (
2223
TimedeltaSeries,
@@ -34,7 +35,12 @@ from pandas.core.dtypes.dtypes import DatetimeTZDtype
3435

3536
from pandas.tseries.offsets import BaseOffset
3637

37-
class DatetimeIndex(DatetimeTimedeltaMixin, DatetimeIndexProperties):
38+
# type ignore needed because of __getitem__()
39+
class DatetimeIndex( # type: ignore[misc]
40+
_IndexGetitemMixin[Timestamp],
41+
DatetimeTimedeltaMixin,
42+
DatetimeIndexProperties,
43+
):
3844
def __init__(
3945
self,
4046
data: ArrayLike | AnyArrayLike | list | tuple,
@@ -53,11 +59,11 @@ class DatetimeIndex(DatetimeTimedeltaMixin, DatetimeIndexProperties):
5359
def __reduce__(self): ...
5460
# various ignores needed for mypy, as we do want to restrict what can be used in
5561
# arithmetic for these types
56-
@overload # type: ignore[override]
62+
@overload
5763
def __add__(self, other: TimedeltaSeries) -> TimestampSeries: ...
5864
@overload
5965
def __add__(self, other: Timedelta | TimedeltaIndex) -> DatetimeIndex: ...
60-
@overload # type: ignore[override]
66+
@overload
6167
def __sub__(self, other: TimedeltaSeries) -> TimestampSeries: ...
6268
@overload
6369
def __sub__(self, other: Timedelta | TimedeltaIndex) -> DatetimeIndex: ...
+1-12
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,3 @@
1-
from typing import Literal
2-
31
from pandas.core.indexes.base import Index
42

5-
class ExtensionIndex(Index):
6-
def __iter__(self): ...
7-
def dropna(self, how: Literal["any", "all"] = ...): ...
8-
def repeat(self, repeats, axis=...): ...
9-
def take(
10-
self, indices, axis: int = ..., allow_fill: bool = ..., fill_value=..., **kwargs
11-
): ...
12-
def unique(self, level=...): ...
13-
def map(self, mapper, na_action=...): ...
14-
def astype(self, dtype, copy: bool = ...): ...
3+
class ExtensionIndex(Index): ...

pandas-stubs/core/indexes/interval.pyi

+24-15
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ from typing import (
1313
import numpy as np
1414
import pandas as pd
1515
from pandas import Index
16-
from pandas.core.indexes.extension import ExtensionIndex
1716
from pandas.core.series import (
17+
Series,
1818
TimedeltaSeries,
1919
TimestampSeries,
2020
)
@@ -32,6 +32,7 @@ from pandas._typing import (
3232
IntervalClosedType,
3333
IntervalT,
3434
Label,
35+
np_ndarray_anyint,
3536
np_ndarray_bool,
3637
npt,
3738
)
@@ -68,7 +69,7 @@ _EdgesTimedelta: TypeAlias = Union[
6869
_TimestampLike: TypeAlias = Union[pd.Timestamp, np.datetime64, dt.datetime]
6970
_TimedeltaLike: TypeAlias = Union[pd.Timedelta, np.timedelta64, dt.timedelta]
7071

71-
class IntervalIndex(IntervalMixin, ExtensionIndex, Generic[IntervalT]):
72+
class IntervalIndex(IntervalMixin, Generic[IntervalT]):
7273
def __new__(
7374
cls,
7475
data: Sequence[IntervalT],
@@ -78,10 +79,9 @@ class IntervalIndex(IntervalMixin, ExtensionIndex, Generic[IntervalT]):
7879
name: Hashable = ...,
7980
verify_integrity: bool = ...,
8081
) -> IntervalIndex[IntervalT]: ...
81-
# ignore[misc] here due to overlap, e.g., Sequence[int] and Sequence[float]
8282
@overload
8383
@classmethod
84-
def from_breaks( # type:ignore[misc]
84+
def from_breaks(
8585
cls,
8686
breaks: _EdgesInt,
8787
closed: IntervalClosedType = ...,
@@ -119,10 +119,9 @@ class IntervalIndex(IntervalMixin, ExtensionIndex, Generic[IntervalT]):
119119
copy: bool = ...,
120120
dtype: IntervalDtype | None = ...,
121121
) -> IntervalIndex[Interval[pd.Timedelta]]: ...
122-
# ignore[misc] here due to overlap, e.g., Sequence[int] and Sequence[float]
123122
@overload
124123
@classmethod
125-
def from_arrays( # type:ignore[misc]
124+
def from_arrays(
126125
cls,
127126
left: _EdgesInt,
128127
right: _EdgesInt,
@@ -250,41 +249,51 @@ class IntervalIndex(IntervalMixin, ExtensionIndex, Generic[IntervalT]):
250249
@property
251250
def length(self) -> Index: ...
252251
def get_value(self, series: ABCSeries, key): ...
252+
@overload
253+
def __getitem__(
254+
self,
255+
idx: slice
256+
| np_ndarray_anyint
257+
| Sequence[int]
258+
| Index
259+
| Series[bool]
260+
| np_ndarray_bool,
261+
) -> IntervalIndex[IntervalT]: ...
262+
@overload
263+
def __getitem__(self, idx: int) -> IntervalT: ...
253264
@property
254265
def is_all_dates(self) -> bool: ...
255-
# override is due to additional types for comparison
256-
# misc is due to overlap with object below
257-
@overload # type: ignore[override]
266+
@overload
258267
def __gt__(
259268
self, other: IntervalT | IntervalIndex[IntervalT]
260269
) -> np_ndarray_bool: ...
261270
@overload
262271
def __gt__(self, other: pd.Series[IntervalT]) -> pd.Series[bool]: ...
263-
@overload # type: ignore[override]
272+
@overload
264273
def __ge__(
265274
self, other: IntervalT | IntervalIndex[IntervalT]
266275
) -> np_ndarray_bool: ...
267276
@overload
268277
def __ge__(self, other: pd.Series[IntervalT]) -> pd.Series[bool]: ...
269-
@overload # type: ignore[override]
278+
@overload
270279
def __le__(
271280
self, other: IntervalT | IntervalIndex[IntervalT]
272281
) -> np_ndarray_bool: ...
273282
@overload
274283
def __le__(self, other: pd.Series[IntervalT]) -> pd.Series[bool]: ...
275-
@overload # type: ignore[override]
284+
@overload
276285
def __lt__(
277286
self, other: IntervalT | IntervalIndex[IntervalT]
278287
) -> np_ndarray_bool: ...
279288
@overload
280-
def __lt__(self, other: pd.Series[IntervalT]) -> bool: ... # type: ignore[misc]
281-
@overload # type: ignore[override]
289+
def __lt__(self, other: pd.Series[IntervalT]) -> bool: ...
290+
@overload
282291
def __eq__(self, other: IntervalT | IntervalIndex[IntervalT]) -> np_ndarray_bool: ... # type: ignore[misc]
283292
@overload
284293
def __eq__(self, other: pd.Series[IntervalT]) -> pd.Series[bool]: ... # type: ignore[misc]
285294
@overload
286295
def __eq__(self, other: object) -> Literal[False]: ...
287-
@overload # type: ignore[override]
296+
@overload
288297
def __ne__(self, other: IntervalT | IntervalIndex[IntervalT]) -> np_ndarray_bool: ... # type: ignore[misc]
289298
@overload
290299
def __ne__(self, other: pd.Series[IntervalT]) -> pd.Series[bool]: ... # type: ignore[misc]

pandas-stubs/core/indexes/multi.pyi

+18-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@ from collections.abc import (
33
Hashable,
44
Sequence,
55
)
6-
from typing import Literal
6+
from typing import (
7+
Literal,
8+
overload,
9+
)
710

811
import numpy as np
912
import pandas as pd
@@ -13,6 +16,7 @@ from pandas._typing import (
1316
T1,
1417
DtypeArg,
1518
HashableT,
19+
np_ndarray_anyint,
1620
np_ndarray_bool,
1721
)
1822

@@ -105,7 +109,19 @@ class MultiIndex(Index):
105109
@property
106110
def levshape(self): ...
107111
def __reduce__(self): ...
108-
def __getitem__(self, key): ...
112+
@overload # type: ignore[override]
113+
def __getitem__(
114+
self,
115+
idx: slice
116+
| np_ndarray_anyint
117+
| Sequence[int]
118+
| Index
119+
| pd.Series[bool]
120+
| Sequence[bool]
121+
| np_ndarray_bool,
122+
) -> MultiIndex: ...
123+
@overload
124+
def __getitem__(self, key: int) -> tuple: ...
109125
def take(
110126
self, indices, axis: int = ..., allow_fill: bool = ..., fill_value=..., **kwargs
111127
): ...

pandas-stubs/core/indexes/period.pyi

+24-6
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,25 @@ import numpy as np
55
import pandas as pd
66
from pandas import Index
77
from pandas.core.indexes.accessors import PeriodIndexFieldOps
8+
from pandas.core.indexes.base import _IndexGetitemMixin
89
from pandas.core.indexes.datetimelike import (
910
DatetimeIndexOpsMixin as DatetimeIndexOpsMixin,
1011
)
11-
from pandas.core.indexes.numeric import Int64Index
12-
from pandas.core.series import OffsetSeries
12+
from pandas.core.indexes.timedeltas import TimedeltaIndex
1313

1414
from pandas._libs.tslibs import (
1515
BaseOffset,
16+
NaTType,
1617
Period,
1718
)
19+
from pandas._libs.tslibs.period import _PeriodAddSub
1820

19-
class PeriodIndex(DatetimeIndexOpsMixin, Int64Index, PeriodIndexFieldOps):
21+
# type ignore needed because of __getitem__()
22+
class PeriodIndex( # type: ignore[misc]
23+
_IndexGetitemMixin[Period],
24+
DatetimeIndexOpsMixin,
25+
PeriodIndexFieldOps,
26+
):
2027
def __new__(
2128
cls,
2229
data=...,
@@ -31,11 +38,22 @@ class PeriodIndex(DatetimeIndexOpsMixin, Int64Index, PeriodIndexFieldOps):
3138
@property
3239
def values(self): ...
3340
def __contains__(self, key) -> bool: ...
34-
# Override due to supertype incompatibility which has it for NumericIndex or complex.
35-
@overload # type: ignore[override]
41+
@overload
3642
def __sub__(self, other: Period) -> Index: ...
3743
@overload
38-
def __sub__(self, other: PeriodIndex) -> OffsetSeries: ...
44+
def __sub__(self, other: PeriodIndex) -> Index: ...
45+
@overload
46+
def __sub__(self, other: _PeriodAddSub) -> PeriodIndex: ...
47+
@overload
48+
def __sub__(self, other: NaTType) -> NaTType: ...
49+
@overload
50+
def __sub__(self, other: TimedeltaIndex | pd.Timedelta) -> PeriodIndex: ...
51+
@overload # type: ignore[override]
52+
def __rsub__(self, other: Period) -> Index: ...
53+
@overload
54+
def __rsub__(self, other: PeriodIndex) -> Index: ...
55+
@overload
56+
def __rsub__(self, other: NaTType) -> NaTType: ...
3957
def __array__(self, dtype=...) -> np.ndarray: ...
4058
def __array_wrap__(self, result, context=...): ...
4159
def asof_locs(self, where, mask): ...

0 commit comments

Comments
 (0)