Skip to content

Commit f945e38

Browse files
author
Laurent Mutricy
committed
update algo.take to solve pandas-dev#59177
1 parent dcb5494 commit f945e38

File tree

3 files changed

+18
-3
lines changed

3 files changed

+18
-3
lines changed

pandas/core/algorithms.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1164,8 +1164,8 @@ def take(
11641164
if not isinstance(arr, (np.ndarray, ABCExtensionArray, ABCIndex, ABCSeries)):
11651165
# GH#52981
11661166
raise TypeError(
1167-
"pd.api.extensions.take requires a numpy.ndarray, "
1168-
f"ExtensionArray, Index, or Series, got {type(arr).__name__}."
1167+
"pd.api.extensions.take requires a numpy.ndarray, ExtensionArray, "
1168+
f"Index, Series, or NumpyExtensionArray got {type(arr).__name__}."
11691169
)
11701170

11711171
indices = ensure_platform_int(indices)

pandas/core/dtypes/generic.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,14 @@ def _subclasscheck(cls, inst) -> bool:
139139
"ABCExtensionArray",
140140
"_typ",
141141
# Note: IntervalArray and SparseArray are included bc they have _typ="extension"
142-
{"extension", "categorical", "periodarray", "datetimearray", "timedeltaarray"},
142+
{
143+
"extension",
144+
"categorical",
145+
"periodarray",
146+
"datetimearray",
147+
"timedeltaarray",
148+
"npy_extension",
149+
},
143150
),
144151
)
145152
ABCNumpyExtensionArray = cast(

pandas/tests/test_take.py

+8
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from pandas._libs import iNaT
77

8+
from pandas import array
89
import pandas._testing as tm
910
import pandas.core.algorithms as algos
1011

@@ -307,3 +308,10 @@ def test_take_coerces_list(self):
307308
)
308309
with pytest.raises(TypeError, match=msg):
309310
algos.take(arr, [0, 0])
311+
312+
def test_take_NumpyExtensionArray(self):
313+
# GH#59177
314+
arr = array([1 + 1j, 2, 3]) # NumpyEADtype('complex128') (NumpyExtensionArray)
315+
assert algos.take(arr, [2]) == 2
316+
arr = array([1, 2, 3]) # Int64Dtype() (ExtensionArray)
317+
assert algos.take(arr, [2]) == 2

0 commit comments

Comments
 (0)