Skip to content

Commit 4f738a4

Browse files
TYP: __getitem__ method of EA (#37898)
1 parent 8f0e263 commit 4f738a4

File tree

8 files changed

+59
-30
lines changed

8 files changed

+59
-30
lines changed

pandas/core/arrays/_mixins.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from typing import Any, Optional, Sequence, Type, TypeVar
1+
from __future__ import annotations
2+
3+
from typing import Any, Optional, Sequence, Type, TypeVar, Union
24

35
import numpy as np
46

@@ -212,7 +214,9 @@ def __setitem__(self, key, value):
212214
def _validate_setitem_value(self, value):
213215
return value
214216

215-
def __getitem__(self, key):
217+
def __getitem__(
218+
self: NDArrayBackedExtensionArrayT, key: Union[int, slice, np.ndarray]
219+
) -> Union[NDArrayBackedExtensionArrayT, Any]:
216220
if lib.is_integer(key):
217221
# fast-path
218222
result = self._ndarray[key]

pandas/core/arrays/base.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
This is an experimental API and subject to breaking changes
77
without warning.
88
"""
9+
from __future__ import annotations
10+
911
import operator
1012
from typing import (
1113
Any,
@@ -254,8 +256,9 @@ def _from_factorized(cls, values, original):
254256
# Must be a Sequence
255257
# ------------------------------------------------------------------------
256258

257-
def __getitem__(self, item):
258-
# type (Any) -> Any
259+
def __getitem__(
260+
self, item: Union[int, slice, np.ndarray]
261+
) -> Union[ExtensionArray, Any]:
259262
"""
260263
Select a subset of self.
261264
@@ -661,7 +664,7 @@ def dropna(self):
661664
"""
662665
return self[~self.isna()]
663666

664-
def shift(self, periods: int = 1, fill_value: object = None) -> "ExtensionArray":
667+
def shift(self, periods: int = 1, fill_value: object = None) -> ExtensionArray:
665668
"""
666669
Shift values by desired number.
667670
@@ -831,7 +834,7 @@ def _values_for_factorize(self) -> Tuple[np.ndarray, Any]:
831834
"""
832835
return self.astype(object), np.nan
833836

834-
def factorize(self, na_sentinel: int = -1) -> Tuple[np.ndarray, "ExtensionArray"]:
837+
def factorize(self, na_sentinel: int = -1) -> Tuple[np.ndarray, ExtensionArray]:
835838
"""
836839
Encode the extension array as an enumerated type.
837840
@@ -940,7 +943,7 @@ def take(
940943
*,
941944
allow_fill: bool = False,
942945
fill_value: Any = None,
943-
) -> "ExtensionArray":
946+
) -> ExtensionArray:
944947
"""
945948
Take elements from an array.
946949
@@ -1109,7 +1112,7 @@ def _formatter(self, boxed: bool = False) -> Callable[[Any], Optional[str]]:
11091112
# Reshaping
11101113
# ------------------------------------------------------------------------
11111114

1112-
def transpose(self, *axes) -> "ExtensionArray":
1115+
def transpose(self, *axes) -> ExtensionArray:
11131116
"""
11141117
Return a transposed view on this array.
11151118
@@ -1119,10 +1122,10 @@ def transpose(self, *axes) -> "ExtensionArray":
11191122
return self[:]
11201123

11211124
@property
1122-
def T(self) -> "ExtensionArray":
1125+
def T(self) -> ExtensionArray:
11231126
return self.transpose()
11241127

1125-
def ravel(self, order="C") -> "ExtensionArray":
1128+
def ravel(self, order="C") -> ExtensionArray:
11261129
"""
11271130
Return a flattened view on this array.
11281131

pandas/core/arrays/datetimelike.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
from datetime import datetime, timedelta
24
import operator
35
from typing import (
@@ -264,7 +266,9 @@ def __array__(self, dtype=None) -> np.ndarray:
264266
return np.array(list(self), dtype=object)
265267
return self._ndarray
266268

267-
def __getitem__(self, key):
269+
def __getitem__(
270+
self, key: Union[int, slice, np.ndarray]
271+
) -> Union[DatetimeLikeArrayMixin, DTScalarOrNaT]:
268272
"""
269273
This getitem defers to the underlying array, which by-definition can
270274
only handle list-likes, slices, and integer scalars

pandas/core/arrays/datetimes.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from datetime import datetime, time, timedelta, tzinfo
2-
from typing import Optional, Union
2+
from typing import Optional, Union, cast
33
import warnings
44

55
import numpy as np
@@ -444,9 +444,11 @@ def _generate_range(
444444
)
445445

446446
if not left_closed and len(index) and index[0] == start:
447-
index = index[1:]
447+
# TODO: overload DatetimeLikeArrayMixin.__getitem__
448+
index = cast(DatetimeArray, index[1:])
448449
if not right_closed and len(index) and index[-1] == end:
449-
index = index[:-1]
450+
# TODO: overload DatetimeLikeArrayMixin.__getitem__
451+
index = cast(DatetimeArray, index[:-1])
450452

451453
dtype = tz_to_dtype(tz)
452454
return cls._simple_new(index.asi8, freq=freq, dtype=dtype)

pandas/core/arrays/masked.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Type, TypeVar
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Any, Optional, Sequence, Tuple, Type, TypeVar, Union
24

35
import numpy as np
46

@@ -56,7 +58,7 @@ def itemsize(self) -> int:
5658
return self.numpy_dtype.itemsize
5759

5860
@classmethod
59-
def construct_array_type(cls) -> Type["BaseMaskedArray"]:
61+
def construct_array_type(cls) -> Type[BaseMaskedArray]:
6062
"""
6163
Return the array type associated with this dtype.
6264
@@ -100,7 +102,9 @@ def __init__(self, values: np.ndarray, mask: np.ndarray, copy: bool = False):
100102
def dtype(self) -> BaseMaskedDtype:
101103
raise AbstractMethodError(self)
102104

103-
def __getitem__(self, item):
105+
def __getitem__(
106+
self, item: Union[int, slice, np.ndarray]
107+
) -> Union[BaseMaskedArray, Any]:
104108
if is_integer(item):
105109
if self._mask[item]:
106110
return self.dtype.na_value

pandas/core/groupby/groupby.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -1671,10 +1671,10 @@ def first(self, numeric_only: bool = False, min_count: int = -1):
16711671
def first_compat(obj: FrameOrSeries, axis: int = 0):
16721672
def first(x: Series):
16731673
"""Helper function for first item that isn't NA."""
1674-
x = x.array[notna(x.array)]
1675-
if len(x) == 0:
1674+
arr = x.array[notna(x.array)]
1675+
if not len(arr):
16761676
return np.nan
1677-
return x[0]
1677+
return arr[0]
16781678

16791679
if isinstance(obj, DataFrame):
16801680
return obj.apply(first, axis=axis)
@@ -1695,10 +1695,10 @@ def last(self, numeric_only: bool = False, min_count: int = -1):
16951695
def last_compat(obj: FrameOrSeries, axis: int = 0):
16961696
def last(x: Series):
16971697
"""Helper function for last item that isn't NA."""
1698-
x = x.array[notna(x.array)]
1699-
if len(x) == 0:
1698+
arr = x.array[notna(x.array)]
1699+
if not len(arr):
17001700
return np.nan
1701-
return x[-1]
1701+
return arr[-1]
17021702

17031703
if isinstance(obj, DataFrame):
17041704
return obj.apply(last, axis=axis)

pandas/core/indexes/base.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -3222,8 +3222,14 @@ def _get_nearest_indexer(self, target: "Index", limit, tolerance) -> np.ndarray:
32223222
right_indexer = self.get_indexer(target, "backfill", limit=limit)
32233223

32243224
target_values = target._values
3225-
left_distances = np.abs(self._values[left_indexer] - target_values)
3226-
right_distances = np.abs(self._values[right_indexer] - target_values)
3225+
# error: Unsupported left operand type for - ("ExtensionArray")
3226+
left_distances = np.abs(
3227+
self._values[left_indexer] - target_values # type: ignore[operator]
3228+
)
3229+
# error: Unsupported left operand type for - ("ExtensionArray")
3230+
right_distances = np.abs(
3231+
self._values[right_indexer] - target_values # type: ignore[operator]
3232+
)
32273233

32283234
op = operator.lt if self.is_monotonic_increasing else operator.le
32293235
indexer = np.where(
@@ -3242,7 +3248,8 @@ def _filter_indexer_tolerance(
32423248
indexer: np.ndarray,
32433249
tolerance,
32443250
) -> np.ndarray:
3245-
distance = abs(self._values[indexer] - target)
3251+
# error: Unsupported left operand type for - ("ExtensionArray")
3252+
distance = abs(self._values[indexer] - target) # type: ignore[operator]
32463253
indexer = np.where(distance <= tolerance, indexer, -1)
32473254
return indexer
32483255

@@ -3446,6 +3453,7 @@ def reindex(self, target, method=None, level=None, limit=None, tolerance=None):
34463453
target = ensure_has_len(target) # target may be an iterator
34473454

34483455
if not isinstance(target, Index) and len(target) == 0:
3456+
values: Union[range, ExtensionArray, np.ndarray]
34493457
if isinstance(self, ABCRangeIndex):
34503458
values = range(0)
34513459
else:
@@ -4538,8 +4546,9 @@ def asof_locs(self, where: "Index", mask) -> np.ndarray:
45384546

45394547
result = np.arange(len(self))[mask].take(locs)
45404548

4541-
first = mask.argmax()
4542-
result[(locs == 0) & (where._values < self._values[first])] = -1
4549+
# TODO: overload return type of ExtensionArray.__getitem__
4550+
first_value = cast(Any, self._values[mask.argmax()])
4551+
result[(locs == 0) & (where._values < first_value)] = -1
45434552

45444553
return result
45454554

pandas/core/indexes/period.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from datetime import datetime, timedelta
2-
from typing import Any
2+
from typing import Any, cast
33

44
import numpy as np
55

@@ -673,7 +673,10 @@ def difference(self, other, sort=None):
673673

674674
if self.equals(other):
675675
# pass an empty PeriodArray with the appropriate dtype
676-
return type(self)._simple_new(self._data[:0], name=self.name)
676+
677+
# TODO: overload DatetimeLikeArrayMixin.__getitem__
678+
values = cast(PeriodArray, self._data[:0])
679+
return type(self)._simple_new(values, name=self.name)
677680

678681
if is_object_dtype(other):
679682
return self.astype(object).difference(other).astype(self.dtype)

0 commit comments

Comments
 (0)