Skip to content

Commit bc94e28

Browse files
authored
TYP: Typing for ExtensionArray.__getitem__ (#41258)
1 parent b17379b commit bc94e28

17 files changed

+215
-79
lines changed

pandas/_typing.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -206,10 +206,16 @@
206206
# indexing
207207
# PositionalIndexer -> valid 1D positional indexer, e.g. can pass
208208
# to ndarray.__getitem__
209+
# ScalarIndexer is for a single value as the index
210+
# SequenceIndexer is for list-likes or slices (but not tuples)
211+
# PositionalIndexerTuple extends the PositionalIndexer for 2D arrays
212+
# These are used in various __getitem__ overloads
209213
# TODO: add Ellipsis, see
210214
# https://github.com/python/typing/issues/684#issuecomment-548203158
211215
# https://bugs.python.org/issue41810
212-
PositionalIndexer = Union[int, np.integer, slice, Sequence[int], np.ndarray]
213-
PositionalIndexer2D = Union[
214-
PositionalIndexer, Tuple[PositionalIndexer, PositionalIndexer]
215-
]
216+
# Using List[int] here rather than Sequence[int] to disallow tuples.
217+
ScalarIndexer = Union[int, np.integer]
218+
SequenceIndexer = Union[slice, List[int], np.ndarray]
219+
PositionalIndexer = Union[ScalarIndexer, SequenceIndexer]
220+
PositionalIndexerTuple = Tuple[PositionalIndexer, PositionalIndexer]
221+
PositionalIndexer2D = Union[PositionalIndexer, PositionalIndexerTuple]

pandas/core/arrays/_mixins.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
from typing import (
55
TYPE_CHECKING,
66
Any,
7+
Literal,
78
Sequence,
89
TypeVar,
910
cast,
11+
overload,
1012
)
1113

1214
import numpy as np
@@ -16,6 +18,9 @@
1618
from pandas._typing import (
1719
F,
1820
PositionalIndexer2D,
21+
PositionalIndexerTuple,
22+
ScalarIndexer,
23+
SequenceIndexer,
1924
Shape,
2025
npt,
2126
type_t,
@@ -48,7 +53,6 @@
4853
)
4954

5055
if TYPE_CHECKING:
51-
from typing import Literal
5256

5357
from pandas._typing import (
5458
NumpySorter,
@@ -205,6 +209,17 @@ def __setitem__(self, key, value):
205209
def _validate_setitem_value(self, value):
206210
return value
207211

212+
@overload
213+
def __getitem__(self, key: ScalarIndexer) -> Any:
214+
...
215+
216+
@overload
217+
def __getitem__(
218+
self: NDArrayBackedExtensionArrayT,
219+
key: SequenceIndexer | PositionalIndexerTuple,
220+
) -> NDArrayBackedExtensionArrayT:
221+
...
222+
208223
def __getitem__(
209224
self: NDArrayBackedExtensionArrayT,
210225
key: PositionalIndexer2D,

pandas/core/arrays/base.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
Dtype,
3131
FillnaOptions,
3232
PositionalIndexer,
33+
ScalarIndexer,
34+
SequenceIndexer,
3335
Shape,
3436
npt,
3537
)
@@ -298,8 +300,17 @@ def _from_factorized(cls, values, original):
298300
# ------------------------------------------------------------------------
299301
# Must be a Sequence
300302
# ------------------------------------------------------------------------
303+
@overload
304+
def __getitem__(self, item: ScalarIndexer) -> Any:
305+
...
306+
307+
@overload
308+
def __getitem__(self: ExtensionArrayT, item: SequenceIndexer) -> ExtensionArrayT:
309+
...
301310

302-
def __getitem__(self, item: PositionalIndexer) -> ExtensionArray | Any:
311+
def __getitem__(
312+
self: ExtensionArrayT, item: PositionalIndexer
313+
) -> ExtensionArrayT | Any:
303314
"""
304315
Select a subset of self.
305316
@@ -313,6 +324,8 @@ def __getitem__(self, item: PositionalIndexer) -> ExtensionArray | Any:
313324
314325
* ndarray: A 1-d boolean NumPy ndarray the same length as 'self'
315326
327+
* list[int]: A list of int
328+
316329
Returns
317330
-------
318331
item : scalar or ExtensionArray
@@ -761,7 +774,7 @@ def fillna(
761774
new_values = self.copy()
762775
return new_values
763776

764-
def dropna(self):
777+
def dropna(self: ExtensionArrayT) -> ExtensionArrayT:
765778
"""
766779
Return ExtensionArray without NA values.
767780

pandas/core/arrays/categorical.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from shutil import get_terminal_size
77
from typing import (
88
TYPE_CHECKING,
9+
Any,
910
Hashable,
1011
Sequence,
1112
TypeVar,
@@ -37,7 +38,11 @@
3738
Dtype,
3839
NpDtype,
3940
Ordered,
41+
PositionalIndexer2D,
42+
PositionalIndexerTuple,
4043
Scalar,
44+
ScalarIndexer,
45+
SequenceIndexer,
4146
Shape,
4247
npt,
4348
type_t,
@@ -2017,7 +2022,18 @@ def __repr__(self) -> str:
20172022

20182023
# ------------------------------------------------------------------
20192024

2020-
def __getitem__(self, key):
2025+
@overload
2026+
def __getitem__(self, key: ScalarIndexer) -> Any:
2027+
...
2028+
2029+
@overload
2030+
def __getitem__(
2031+
self: CategoricalT,
2032+
key: SequenceIndexer | PositionalIndexerTuple,
2033+
) -> CategoricalT:
2034+
...
2035+
2036+
def __getitem__(self: CategoricalT, key: PositionalIndexer2D) -> CategoricalT | Any:
20212037
"""
20222038
Return an item.
20232039
"""

pandas/core/arrays/datetimelike.py

+24-9
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@
4949
DtypeObj,
5050
NpDtype,
5151
PositionalIndexer2D,
52+
PositionalIndexerTuple,
53+
ScalarIndexer,
54+
SequenceIndexer,
5255
npt,
5356
)
5457
from pandas.compat.numpy import function as nv
@@ -313,17 +316,33 @@ def __array__(self, dtype: NpDtype | None = None) -> np.ndarray:
313316
return np.array(list(self), dtype=object)
314317
return self._ndarray
315318

319+
@overload
320+
def __getitem__(self, item: ScalarIndexer) -> DTScalarOrNaT:
321+
...
322+
323+
@overload
316324
def __getitem__(
317-
self, key: PositionalIndexer2D
318-
) -> DatetimeLikeArrayMixin | DTScalarOrNaT:
325+
self: DatetimeLikeArrayT,
326+
item: SequenceIndexer | PositionalIndexerTuple,
327+
) -> DatetimeLikeArrayT:
328+
...
329+
330+
def __getitem__(
331+
self: DatetimeLikeArrayT, key: PositionalIndexer2D
332+
) -> DatetimeLikeArrayT | DTScalarOrNaT:
319333
"""
320334
This getitem defers to the underlying array, which by definition can
321335
only handle list-likes, slices, and integer scalars
322336
"""
323-
result = super().__getitem__(key)
337+
# Use cast as we know we will get back a DatetimeLikeArray or DTScalar
338+
result = cast(
339+
Union[DatetimeLikeArrayT, DTScalarOrNaT], super().__getitem__(key)
340+
)
324341
if lib.is_scalar(result):
325342
return result
326-
343+
else:
344+
# At this point we know the result is an array.
345+
result = cast(DatetimeLikeArrayT, result)
327346
result._freq = self._get_getitem_freq(key)
328347
return result
329348

@@ -1768,11 +1787,7 @@ def factorize(self, na_sentinel=-1, sort: bool = False):
17681787
uniques = self.copy() # TODO: copy or view?
17691788
if sort and self.freq.n < 0:
17701789
codes = codes[::-1]
1771-
# TODO: overload __getitem__, a slice indexer returns same type as self
1772-
# error: Incompatible types in assignment (expression has type
1773-
# "Union[DatetimeLikeArrayMixin, Union[Any, Any]]", variable
1774-
# has type "TimelikeOps")
1775-
uniques = uniques[::-1] # type: ignore[assignment]
1790+
uniques = uniques[::-1]
17761791
return codes, uniques
17771792
# FIXME: shouldn't get here; we are ignoring sort
17781793
return super().factorize(na_sentinel=na_sentinel)

pandas/core/arrays/datetimes.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from typing import (
1010
TYPE_CHECKING,
1111
Literal,
12-
cast,
1312
overload,
1413
)
1514
import warnings
@@ -478,11 +477,9 @@ def _generate_range(
478477
index = cls._simple_new(arr, freq=None, dtype=dtype)
479478

480479
if not left_closed and len(index) and index[0] == start:
481-
# TODO: overload DatetimeLikeArrayMixin.__getitem__
482-
index = cast(DatetimeArray, index[1:])
480+
index = index[1:]
483481
if not right_closed and len(index) and index[-1] == end:
484-
# TODO: overload DatetimeLikeArrayMixin.__getitem__
485-
index = cast(DatetimeArray, index[:-1])
482+
index = index[:-1]
486483

487484
dtype = tz_to_dtype(tz)
488485
return cls._simple_new(index._ndarray, freq=freq, dtype=dtype)

pandas/core/arrays/interval.py

+22-5
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
from typing import (
1010
Sequence,
1111
TypeVar,
12+
Union,
1213
cast,
14+
overload,
1315
)
1416

1517
import numpy as np
@@ -31,6 +33,9 @@
3133
ArrayLike,
3234
Dtype,
3335
NpDtype,
36+
PositionalIndexer,
37+
ScalarIndexer,
38+
SequenceIndexer,
3439
)
3540
from pandas.compat.numpy import function as nv
3641
from pandas.util._decorators import Appender
@@ -89,6 +94,7 @@
8994
)
9095

9196
IntervalArrayT = TypeVar("IntervalArrayT", bound="IntervalArray")
97+
IntervalOrNA = Union[Interval, float]
9298

9399
_interval_shared_docs: dict[str, str] = {}
94100

@@ -635,7 +641,17 @@ def __iter__(self):
635641
def __len__(self) -> int:
636642
return len(self._left)
637643

638-
def __getitem__(self, key):
644+
@overload
645+
def __getitem__(self, key: ScalarIndexer) -> IntervalOrNA:
646+
...
647+
648+
@overload
649+
def __getitem__(self: IntervalArrayT, key: SequenceIndexer) -> IntervalArrayT:
650+
...
651+
652+
def __getitem__(
653+
self: IntervalArrayT, key: PositionalIndexer
654+
) -> IntervalArrayT | IntervalOrNA:
639655
key = check_array_indexer(self, key)
640656
left = self._left[key]
641657
right = self._right[key]
@@ -1633,10 +1649,11 @@ def _from_combined(self, combined: np.ndarray) -> IntervalArray:
16331649
return self._shallow_copy(left=new_left, right=new_right)
16341650

16351651
def unique(self) -> IntervalArray:
1636-
# Invalid index type "Tuple[slice, int]" for "Union[ExtensionArray,
1637-
# ndarray[Any, Any]]"; expected type "Union[int, integer[Any], slice,
1638-
# Sequence[int], ndarray[Any, Any]]"
1639-
nc = unique(self._combined.view("complex128")[:, 0]) # type: ignore[index]
1652+
# No overload variant of "__getitem__" of "ExtensionArray" matches argument
1653+
# type "Tuple[slice, int]"
1654+
nc = unique(
1655+
self._combined.view("complex128")[:, 0] # type: ignore[call-overload]
1656+
)
16401657
nc = nc[:, None]
16411658
return self._from_combined(nc)
16421659

pandas/core/arrays/masked.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
NpDtype,
2121
PositionalIndexer,
2222
Scalar,
23+
ScalarIndexer,
24+
SequenceIndexer,
2325
npt,
2426
type_t,
2527
)
@@ -139,7 +141,17 @@ def __init__(self, values: np.ndarray, mask: np.ndarray, copy: bool = False):
139141
def dtype(self) -> BaseMaskedDtype:
140142
raise AbstractMethodError(self)
141143

142-
def __getitem__(self, item: PositionalIndexer) -> BaseMaskedArray | Any:
144+
@overload
145+
def __getitem__(self, item: ScalarIndexer) -> Any:
146+
...
147+
148+
@overload
149+
def __getitem__(self: BaseMaskedArrayT, item: SequenceIndexer) -> BaseMaskedArrayT:
150+
...
151+
152+
def __getitem__(
153+
self: BaseMaskedArrayT, item: PositionalIndexer
154+
) -> BaseMaskedArrayT | Any:
143155
if is_integer(item):
144156
if self._mask[item]:
145157
return self.dtype.na_value

pandas/core/arrays/period.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
TYPE_CHECKING,
77
Any,
88
Callable,
9+
Literal,
910
Sequence,
1011
)
1112

@@ -76,7 +77,6 @@
7677
import pandas.core.common as com
7778

7879
if TYPE_CHECKING:
79-
from typing import Literal
8080

8181
from pandas._typing import (
8282
NumpySorter,

0 commit comments

Comments
 (0)