From f945e38b99bdf80a9c607252d6d1b0444cda2b58 Mon Sep 17 00:00:00 2001 From: Laurent Mutricy Date: Thu, 4 Jul 2024 15:38:08 +0200 Subject: [PATCH 1/4] update algo.take to solve #59177 --- pandas/core/algorithms.py | 4 ++-- pandas/core/dtypes/generic.py | 9 ++++++++- pandas/tests/test_take.py | 8 ++++++++ 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/pandas/core/algorithms.py b/pandas/core/algorithms.py index 0d97f8a298fdb..16444150aaf7e 100644 --- a/pandas/core/algorithms.py +++ b/pandas/core/algorithms.py @@ -1164,8 +1164,8 @@ def take( if not isinstance(arr, (np.ndarray, ABCExtensionArray, ABCIndex, ABCSeries)): # GH#52981 raise TypeError( - "pd.api.extensions.take requires a numpy.ndarray, " - f"ExtensionArray, Index, or Series, got {type(arr).__name__}." + "pd.api.extensions.take requires a numpy.ndarray, ExtensionArray, " + f"Index, Series, or NumpyExtensionArray got {type(arr).__name__}." ) indices = ensure_platform_int(indices) diff --git a/pandas/core/dtypes/generic.py b/pandas/core/dtypes/generic.py index 8d3d86217dedf..d2ab4e982bd81 100644 --- a/pandas/core/dtypes/generic.py +++ b/pandas/core/dtypes/generic.py @@ -139,7 +139,14 @@ def _subclasscheck(cls, inst) -> bool: "ABCExtensionArray", "_typ", # Note: IntervalArray and SparseArray are included bc they have _typ="extension" - {"extension", "categorical", "periodarray", "datetimearray", "timedeltaarray"}, + { + "extension", + "categorical", + "periodarray", + "datetimearray", + "timedeltaarray", + "npy_extension", + }, ), ) ABCNumpyExtensionArray = cast( diff --git a/pandas/tests/test_take.py b/pandas/tests/test_take.py index ce2e4e0f6cec5..0ef6bd62ef0ac 100644 --- a/pandas/tests/test_take.py +++ b/pandas/tests/test_take.py @@ -5,6 +5,7 @@ from pandas._libs import iNaT +from pandas import array import pandas._testing as tm import pandas.core.algorithms as algos @@ -307,3 +308,10 @@ def test_take_coerces_list(self): ) with pytest.raises(TypeError, match=msg): algos.take(arr, [0, 0]) + + def test_take_NumpyExtensionArray(self): + # GH#59177 + arr = array([1 + 1j, 2, 3]) # NumpyEADtype('complex128') (NumpyExtensionArray) + assert algos.take(arr, [2]) == 2 + arr = array([1, 2, 3]) # Int64Dtype() (ExtensionArray) + assert algos.take(arr, [2]) == 2 From 78e50d9fc38d004f561c8a25a7af077833174ebc Mon Sep 17 00:00:00 2001 From: Laurent Mutricy Date: Thu, 4 Jul 2024 16:51:30 +0200 Subject: [PATCH 2/4] forgot to update TestExtensionTake::test_take_coerces_list --- pandas/tests/test_take.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/tests/test_take.py b/pandas/tests/test_take.py index 0ef6bd62ef0ac..451ef42fff3d1 100644 --- a/pandas/tests/test_take.py +++ b/pandas/tests/test_take.py @@ -304,7 +304,7 @@ def test_take_coerces_list(self): arr = [1, 2, 3] msg = ( "pd.api.extensions.take requires a numpy.ndarray, ExtensionArray, " - "Index, or Series, got list" + "Index, Series, or NumpyExtensionArray got list" ) with pytest.raises(TypeError, match=msg): algos.take(arr, [0, 0]) From a9cdf6e4f98a4894f2e2b11df7a22f3f57fb5c48 Mon Sep 17 00:00:00 2001 From: Laurent Mutricy Date: Fri, 5 Jul 2024 09:27:46 +0200 Subject: [PATCH 3/4] fixing pandas/tests/dtypes/test_generic.py::TestABCClasses::test_abc_hierarchy --- pandas/core/algorithms.py | 6 +++++- pandas/core/dtypes/generic.py | 1 - 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/pandas/core/algorithms.py b/pandas/core/algorithms.py index 16444150aaf7e..92bd55cac9c5e 100644 --- a/pandas/core/algorithms.py +++ b/pandas/core/algorithms.py @@ -68,6 +68,7 @@ ABCExtensionArray, ABCIndex, ABCMultiIndex, + ABCNumpyExtensionArray, ABCSeries, ABCTimedeltaArray, ) @@ -1161,7 +1162,10 @@ def take( ... ) array([ 10, 10, -10]) """ - if not isinstance(arr, (np.ndarray, ABCExtensionArray, ABCIndex, ABCSeries)): + if not isinstance( + arr, + (np.ndarray, ABCExtensionArray, ABCIndex, ABCSeries, ABCNumpyExtensionArray), + ): # GH#52981 raise TypeError( "pd.api.extensions.take requires a numpy.ndarray, ExtensionArray, " diff --git a/pandas/core/dtypes/generic.py b/pandas/core/dtypes/generic.py index d2ab4e982bd81..3bd7e0573bf2e 100644 --- a/pandas/core/dtypes/generic.py +++ b/pandas/core/dtypes/generic.py @@ -145,7 +145,6 @@ def _subclasscheck(cls, inst) -> bool: "periodarray", "datetimearray", "timedeltaarray", - "npy_extension", }, ), ) From cfaea959ef005d17440a8e5a3991db8be9b4b554 Mon Sep 17 00:00:00 2001 From: Laurent Mutricy Date: Sat, 6 Jul 2024 13:36:26 +0200 Subject: [PATCH 4/4] ABCExtensionArray set formatting --- pandas/core/dtypes/generic.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/pandas/core/dtypes/generic.py b/pandas/core/dtypes/generic.py index 3bd7e0573bf2e..8d3d86217dedf 100644 --- a/pandas/core/dtypes/generic.py +++ b/pandas/core/dtypes/generic.py @@ -139,13 +139,7 @@ def _subclasscheck(cls, inst) -> bool: "ABCExtensionArray", "_typ", # Note: IntervalArray and SparseArray are included bc they have _typ="extension" - { - "extension", - "categorical", - "periodarray", - "datetimearray", - "timedeltaarray", - }, + {"extension", "categorical", "periodarray", "datetimearray", "timedeltaarray"}, ), ) ABCNumpyExtensionArray = cast(