Skip to content

TYP: ExtensionArray.take accept np.ndarray #43418

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Sep 26, 2021
4 changes: 4 additions & 0 deletions pandas/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,10 @@
PositionalIndexer = Union[ScalarIndexer, SequenceIndexer]
PositionalIndexerTuple = Tuple[PositionalIndexer, PositionalIndexer]
PositionalIndexer2D = Union[PositionalIndexer, PositionalIndexerTuple]
if TYPE_CHECKING:
TakeIndexer = Union[Sequence[int], Sequence[np.integer], npt.NDArray[np.integer]]
else:
TakeIndexer = Any

# Windowing rank methods
WindowingRankType = Literal["average", "min", "max"]
9 changes: 7 additions & 2 deletions pandas/core/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
ArrayLike,
DtypeObj,
Scalar,
TakeIndexer,
npt,
)
from pandas.util._decorators import doc
Expand Down Expand Up @@ -1431,7 +1432,11 @@ def get_indexer(current_indexer, other_indexer):


def take(
arr, indices: np.ndarray, axis: int = 0, allow_fill: bool = False, fill_value=None
arr,
indices: TakeIndexer,
axis: int = 0,
allow_fill: bool = False,
fill_value=None,
):
"""
Take elements from an array.
Expand All @@ -1441,7 +1446,7 @@ def take(
arr : array-like or scalar value
Non array-likes (sequences/scalars without a dtype) are coerced
to an ndarray.
indices : sequence of integers
indices : sequence of int or one-dimensional np.ndarray of int
Indices to be taken.
axis : int, default 0
The axis over which to select values.
Expand Down
9 changes: 1 addition & 8 deletions pandas/core/array_algos/take.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,14 +178,7 @@ def take_1d(
"""
if not isinstance(arr, np.ndarray):
# ExtensionArray -> dispatch to their method

# error: Argument 1 to "take" of "ExtensionArray" has incompatible type
# "ndarray"; expected "Sequence[int]"
return arr.take(
indexer, # type: ignore[arg-type]
fill_value=fill_value,
allow_fill=allow_fill,
)
return arr.take(indexer, fill_value=fill_value, allow_fill=allow_fill)

if not allow_fill:
return arr.take(indexer)
Expand Down
7 changes: 3 additions & 4 deletions pandas/core/arrays/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
ScalarIndexer,
SequenceIndexer,
Shape,
TakeIndexer,
npt,
type_t,
)
Expand Down Expand Up @@ -101,7 +102,7 @@ def _validate_scalar(self, value):

def take(
self: NDArrayBackedExtensionArrayT,
indices: Sequence[int],
indices: TakeIndexer,
*,
allow_fill: bool = False,
fill_value: Any = None,
Expand All @@ -112,9 +113,7 @@ def take(

new_data = take(
self._ndarray,
# error: Argument 2 to "take" has incompatible type "Sequence[int]";
# expected "ndarray"
indices, # type: ignore[arg-type]
indices,
allow_fill=allow_fill,
fill_value=fill_value,
axis=axis,
Expand Down
5 changes: 3 additions & 2 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
ScalarIndexer,
SequenceIndexer,
Shape,
TakeIndexer,
npt,
)
from pandas.compat import set_function_name
Expand Down Expand Up @@ -1076,7 +1077,7 @@ def repeat(self, repeats: int | Sequence[int], axis: int | None = None):

def take(
self: ExtensionArrayT,
indices: Sequence[int],
indices: TakeIndexer,
*,
allow_fill: bool = False,
fill_value: Any = None,
Expand All @@ -1086,7 +1087,7 @@ def take(
Parameters
----------
indices : sequence of int
indices : sequence of int or one-dimensional np.ndarray of int
Indices to be taken.
allow_fill : bool, default False
How to handle negative values in `indices`.
Expand Down
13 changes: 7 additions & 6 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import (
TYPE_CHECKING,
Any,
Sequence,
Union,
cast,
overload,
Expand All @@ -24,6 +23,7 @@
Scalar,
ScalarIndexer,
SequenceIndexer,
TakeIndexer,
npt,
)
from pandas.compat import (
Expand Down Expand Up @@ -307,9 +307,7 @@ def __getitem__(
if not len(item):
return type(self)(pa.chunked_array([], type=pa.string()))
elif is_integer_dtype(item.dtype):
# error: Argument 1 to "take" of "ArrowStringArray" has incompatible
# type "ndarray"; expected "Sequence[int]"
return self.take(item) # type: ignore[arg-type]
return self.take(item)
elif is_bool_dtype(item.dtype):
return type(self)(self._data.filter(item))
else:
Expand Down Expand Up @@ -513,14 +511,17 @@ def __setitem__(self, key: int | slice | np.ndarray, value: Any) -> None:
self[k] = v

def take(
self, indices: Sequence[int], allow_fill: bool = False, fill_value: Any = None
self,
indices: TakeIndexer,
allow_fill: bool = False,
fill_value: Any = None,
):
"""
Take elements from an array.
Parameters
----------
indices : sequence of int
indices : sequence of int or one-dimensional np.ndarray of int
Indices to be taken.
allow_fill : bool, default False
How to handle negative values in `indices`.
Expand Down
8 changes: 2 additions & 6 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4318,12 +4318,8 @@ def _join_non_unique(
)
mask = left_idx == -1

# error: Argument 1 to "take" of "ExtensionArray" has incompatible
# type "ndarray[Any, dtype[signedinteger[Any]]]"; expected "Sequence[int]"
join_array = self._values.take(left_idx) # type: ignore[arg-type]
# error: Argument 1 to "take" of "ExtensionArray" has incompatible type
# "ndarray[Any, dtype[signedinteger[Any]]]"; expected "Sequence[int]"
right = other._values.take(right_idx) # type: ignore[arg-type]
join_array = self._values.take(left_idx)
right = other._values.take(right_idx)

if isinstance(join_array, np.ndarray):
np.putmask(join_array, mask, right)
Expand Down