Skip to content

Commit dd7158b

Browse files
authored
TYP: ExtensionArray.take accept np.ndarray (#43418)
1 parent d1f9a73 commit dd7158b

File tree

7 files changed

+27
-28
lines changed

7 files changed

+27
-28
lines changed

pandas/_typing.py

+4
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,10 @@
219219
PositionalIndexer = Union[ScalarIndexer, SequenceIndexer]
220220
PositionalIndexerTuple = Tuple[PositionalIndexer, PositionalIndexer]
221221
PositionalIndexer2D = Union[PositionalIndexer, PositionalIndexerTuple]
222+
if TYPE_CHECKING:
223+
TakeIndexer = Union[Sequence[int], Sequence[np.integer], npt.NDArray[np.integer]]
224+
else:
225+
TakeIndexer = Any
222226

223227
# Windowing rank methods
224228
WindowingRankType = Literal["average", "min", "max"]

pandas/core/algorithms.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
ArrayLike,
2929
DtypeObj,
3030
Scalar,
31+
TakeIndexer,
3132
npt,
3233
)
3334
from pandas.util._decorators import doc
@@ -1431,7 +1432,11 @@ def get_indexer(current_indexer, other_indexer):
14311432

14321433

14331434
def take(
1434-
arr, indices: np.ndarray, axis: int = 0, allow_fill: bool = False, fill_value=None
1435+
arr,
1436+
indices: TakeIndexer,
1437+
axis: int = 0,
1438+
allow_fill: bool = False,
1439+
fill_value=None,
14351440
):
14361441
"""
14371442
Take elements from an array.
@@ -1441,7 +1446,7 @@ def take(
14411446
arr : array-like or scalar value
14421447
Non array-likes (sequences/scalars without a dtype) are coerced
14431448
to an ndarray.
1444-
indices : sequence of integers
1449+
indices : sequence of int or one-dimensional np.ndarray of int
14451450
Indices to be taken.
14461451
axis : int, default 0
14471452
The axis over which to select values.

pandas/core/array_algos/take.py

+1-8
Original file line numberDiff line numberDiff line change
@@ -178,14 +178,7 @@ def take_1d(
178178
"""
179179
if not isinstance(arr, np.ndarray):
180180
# ExtensionArray -> dispatch to their method
181-
182-
# error: Argument 1 to "take" of "ExtensionArray" has incompatible type
183-
# "ndarray"; expected "Sequence[int]"
184-
return arr.take(
185-
indexer, # type: ignore[arg-type]
186-
fill_value=fill_value,
187-
allow_fill=allow_fill,
188-
)
181+
return arr.take(indexer, fill_value=fill_value, allow_fill=allow_fill)
189182

190183
if not allow_fill:
191184
return arr.take(indexer)

pandas/core/arrays/_mixins.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
ScalarIndexer,
2323
SequenceIndexer,
2424
Shape,
25+
TakeIndexer,
2526
npt,
2627
type_t,
2728
)
@@ -101,7 +102,7 @@ def _validate_scalar(self, value):
101102

102103
def take(
103104
self: NDArrayBackedExtensionArrayT,
104-
indices: Sequence[int],
105+
indices: TakeIndexer,
105106
*,
106107
allow_fill: bool = False,
107108
fill_value: Any = None,
@@ -112,9 +113,7 @@ def take(
112113

113114
new_data = take(
114115
self._ndarray,
115-
# error: Argument 2 to "take" has incompatible type "Sequence[int]";
116-
# expected "ndarray"
117-
indices, # type: ignore[arg-type]
116+
indices,
118117
allow_fill=allow_fill,
119118
fill_value=fill_value,
120119
axis=axis,

pandas/core/arrays/base.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
ScalarIndexer,
3434
SequenceIndexer,
3535
Shape,
36+
TakeIndexer,
3637
npt,
3738
)
3839
from pandas.compat import set_function_name
@@ -1076,7 +1077,7 @@ def repeat(self, repeats: int | Sequence[int], axis: int | None = None):
10761077

10771078
def take(
10781079
self: ExtensionArrayT,
1079-
indices: Sequence[int],
1080+
indices: TakeIndexer,
10801081
*,
10811082
allow_fill: bool = False,
10821083
fill_value: Any = None,
@@ -1086,7 +1087,7 @@ def take(
10861087
10871088
Parameters
10881089
----------
1089-
indices : sequence of int
1090+
indices : sequence of int or one-dimensional np.ndarray of int
10901091
Indices to be taken.
10911092
allow_fill : bool, default False
10921093
How to handle negative values in `indices`.

pandas/core/arrays/string_arrow.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from typing import (
66
TYPE_CHECKING,
77
Any,
8-
Sequence,
98
Union,
109
cast,
1110
overload,
@@ -24,6 +23,7 @@
2423
Scalar,
2524
ScalarIndexer,
2625
SequenceIndexer,
26+
TakeIndexer,
2727
npt,
2828
)
2929
from pandas.compat import (
@@ -307,9 +307,7 @@ def __getitem__(
307307
if not len(item):
308308
return type(self)(pa.chunked_array([], type=pa.string()))
309309
elif is_integer_dtype(item.dtype):
310-
# error: Argument 1 to "take" of "ArrowStringArray" has incompatible
311-
# type "ndarray"; expected "Sequence[int]"
312-
return self.take(item) # type: ignore[arg-type]
310+
return self.take(item)
313311
elif is_bool_dtype(item.dtype):
314312
return type(self)(self._data.filter(item))
315313
else:
@@ -513,14 +511,17 @@ def __setitem__(self, key: int | slice | np.ndarray, value: Any) -> None:
513511
self[k] = v
514512

515513
def take(
516-
self, indices: Sequence[int], allow_fill: bool = False, fill_value: Any = None
514+
self,
515+
indices: TakeIndexer,
516+
allow_fill: bool = False,
517+
fill_value: Any = None,
517518
):
518519
"""
519520
Take elements from an array.
520521
521522
Parameters
522523
----------
523-
indices : sequence of int
524+
indices : sequence of int or one-dimensional np.ndarray of int
524525
Indices to be taken.
525526
allow_fill : bool, default False
526527
How to handle negative values in `indices`.

pandas/core/indexes/base.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -4320,12 +4320,8 @@ def _join_non_unique(
43204320
)
43214321
mask = left_idx == -1
43224322

4323-
# error: Argument 1 to "take" of "ExtensionArray" has incompatible
4324-
# type "ndarray[Any, dtype[signedinteger[Any]]]"; expected "Sequence[int]"
4325-
join_array = self._values.take(left_idx) # type: ignore[arg-type]
4326-
# error: Argument 1 to "take" of "ExtensionArray" has incompatible type
4327-
# "ndarray[Any, dtype[signedinteger[Any]]]"; expected "Sequence[int]"
4328-
right = other._values.take(right_idx) # type: ignore[arg-type]
4323+
join_array = self._values.take(left_idx)
4324+
right = other._values.take(right_idx)
43294325

43304326
if isinstance(join_array, np.ndarray):
43314327
np.putmask(join_array, mask, right)

0 commit comments

Comments
 (0)