diff --git a/pandas/compat/numpy/__init__.py b/pandas/compat/numpy/__init__.py index a2444b7ba5a0d..b82a10fd90ce9 100644 --- a/pandas/compat/numpy/__init__.py +++ b/pandas/compat/numpy/__init__.py @@ -62,6 +62,21 @@ def np_array_datetime64_compat(arr, *args, **kwargs): return np.array(arr, *args, **kwargs) +# taken from dask array +# https://github.com/dask/dask/blob/master/dask/array/utils.py#L352-L363 +def _is_nep18_active(): + class A: + def __array_function__(self, *args, **kwargs): + return True + + try: + return np.concatenate([A()]) + except ValueError: + return False + + +IS_NEP18_ACTIVE = _is_nep18_active() + __all__ = [ "np", "_np_version", diff --git a/pandas/core/arrays/__init__.py b/pandas/core/arrays/__init__.py index 1d538824e6d82..136345b8345c5 100644 --- a/pandas/core/arrays/__init__.py +++ b/pandas/core/arrays/__init__.py @@ -1,4 +1,5 @@ from pandas.core.arrays.base import ( + ArrayFunctionMixin, ExtensionArray, ExtensionOpsMixin, ExtensionScalarOpsMixin, @@ -15,6 +16,7 @@ from pandas.core.arrays.timedeltas import TimedeltaArray __all__ = [ + "ArrayFunctionMixin", "ExtensionArray", "ExtensionOpsMixin", "ExtensionScalarOpsMixin", diff --git a/pandas/core/arrays/_mixins.py b/pandas/core/arrays/_mixins.py index e9d8671b69c78..209220eb528c6 100644 --- a/pandas/core/arrays/_mixins.py +++ b/pandas/core/arrays/_mixins.py @@ -8,12 +8,12 @@ from pandas.core.algorithms import take, unique from pandas.core.array_algos.transforms import shift -from pandas.core.arrays.base import ExtensionArray +from pandas.core.arrays.base import ArrayFunctionMixin, ExtensionArray _T = TypeVar("_T", bound="NDArrayBackedExtensionArray") -class NDArrayBackedExtensionArray(ExtensionArray): +class NDArrayBackedExtensionArray(ExtensionArray, ArrayFunctionMixin): """ ExtensionArray that is backed by a single NumPy ndarray. """ @@ -156,3 +156,19 @@ def _validate_shift_value(self, fill_value): # TODO: after deprecation in datetimelikearraymixin is enforced, # we can remove this and ust validate_fill_value directly return self._validate_fill_value(fill_value) + + # ------------------------------------------------------------------------ + # __array_function__ compat + # ------------------------------------------------------------------------ + + def min(self, skipna: bool = True, **kwargs): + raise AbstractMethodError(self) + + def amin(self, skipna: bool = True, **kwargs): + return self.min(skipna=skipna, **kwargs) + + def max(self, skipna: bool = True, **kwargs): + raise AbstractMethodError(self) + + def amax(self, skipna: bool = True, **kwargs): + return self.max(skipna=skipna, **kwargs) diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index e93cdb608dffb..ba4e4ba3ed50d 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -12,7 +12,7 @@ import numpy as np from pandas._libs import lib -from pandas._typing import ArrayLike +from pandas._typing import ArrayLike, F from pandas.compat import set_function_name from pandas.compat.numpy import function as nv from pandas.errors import AbstractMethodError @@ -37,6 +37,41 @@ _extension_array_shared_docs: Dict[str, str] = dict() +_HANDLED_FUNCTIONS = {} + + +def implements(numpy_function) -> Callable[[F], F]: + """ + Register an __array_function__ implementation for ExtensionArray objects. + """ + + def decorator(func): + _HANDLED_FUNCTIONS[numpy_function] = func + return func + + return decorator + + +class ArrayFunctionMixin: + def __array_function__(self, func, types, args, kwargs): + if func not in _HANDLED_FUNCTIONS: + # try to find a matching method name. If that doesn't work, we may + # be dealing with an alias or a function that's simply not in the + # ExtensionArray API. Handle aliases via the _HANDLED_FUNCTIONS + # dict mapping. + exclude_list = {"unique"} + ea_func = getattr(type(self), func.__name__, None) + if ( + ea_func is None + or not callable(ea_func) + or ea_func.__name__ in exclude_list + ): + return func.__wrapped__(*args, **kwargs) + + return ea_func(*args, **kwargs) + + return _HANDLED_FUNCTIONS[func](*args, **kwargs) + class ExtensionArray: """ diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py index 436a7dd062c4a..6cb6ae3f39e61 100644 --- a/pandas/core/arrays/interval.py +++ b/pandas/core/arrays/interval.py @@ -40,7 +40,11 @@ from pandas.core.dtypes.missing import isna, notna from pandas.core.algorithms import take, value_counts -from pandas.core.arrays.base import ExtensionArray, _extension_array_shared_docs +from pandas.core.arrays.base import ( + ArrayFunctionMixin, + ExtensionArray, + _extension_array_shared_docs, +) from pandas.core.arrays.categorical import Categorical import pandas.core.common as com from pandas.core.construction import array @@ -145,7 +149,7 @@ ), ) ) -class IntervalArray(IntervalMixin, ExtensionArray): +class IntervalArray(IntervalMixin, ExtensionArray, ArrayFunctionMixin): ndim = 1 can_hold_na = True _na_value = _fill_value = np.nan diff --git a/pandas/core/arrays/masked.py b/pandas/core/arrays/masked.py index 31274232e2525..326d2f4c87468 100644 --- a/pandas/core/arrays/masked.py +++ b/pandas/core/arrays/masked.py @@ -19,7 +19,7 @@ from pandas.core import nanops from pandas.core.algorithms import factorize_array, take from pandas.core.array_algos import masked_reductions -from pandas.core.arrays import ExtensionArray, ExtensionOpsMixin +from pandas.core.arrays import ArrayFunctionMixin, ExtensionArray, ExtensionOpsMixin from pandas.core.indexers import check_array_indexer if TYPE_CHECKING: @@ -52,7 +52,7 @@ def construct_array_type(cls) -> Type["BaseMaskedArray"]: raise NotImplementedError -class BaseMaskedArray(ExtensionArray, ExtensionOpsMixin): +class BaseMaskedArray(ExtensionArray, ExtensionOpsMixin, ArrayFunctionMixin): """ Base class for masked arrays (which use _data and _mask to store the data). diff --git a/pandas/core/arrays/sparse/array.py b/pandas/core/arrays/sparse/array.py index 853f7bb0b0d81..02696aadffed2 100644 --- a/pandas/core/arrays/sparse/array.py +++ b/pandas/core/arrays/sparse/array.py @@ -40,7 +40,7 @@ from pandas.core.dtypes.missing import isna, na_value_for_dtype, notna import pandas.core.algorithms as algos -from pandas.core.arrays import ExtensionArray, ExtensionOpsMixin +from pandas.core.arrays import ArrayFunctionMixin, ExtensionArray, ExtensionOpsMixin from pandas.core.arrays.sparse.dtype import SparseDtype from pandas.core.base import PandasObject import pandas.core.common as com @@ -195,7 +195,7 @@ def _wrap_result(name, data, sparse_index, fill_value, dtype=None): ) -class SparseArray(PandasObject, ExtensionArray, ExtensionOpsMixin): +class SparseArray(PandasObject, ExtensionArray, ExtensionOpsMixin, ArrayFunctionMixin): """ An ExtensionArray for storing sparse data. diff --git a/pandas/core/dtypes/concat.py b/pandas/core/dtypes/concat.py index 60fd959701821..010058c9940aa 100644 --- a/pandas/core/dtypes/concat.py +++ b/pandas/core/dtypes/concat.py @@ -17,6 +17,7 @@ from pandas.core.dtypes.generic import ABCCategoricalIndex, ABCRangeIndex, ABCSeries from pandas.core.arrays import ExtensionArray +from pandas.core.arrays.base import implements from pandas.core.arrays.sparse import SparseArray from pandas.core.construction import array @@ -98,6 +99,7 @@ def _cast_to_common_type(arr: ArrayLike, dtype: DtypeObj) -> ArrayLike: return arr.astype(dtype, copy=False) +@implements(np.concatenate) def concat_compat(to_concat, axis: int = 0): """ provide concatenation of an array of arrays each of which is a single @@ -145,11 +147,15 @@ def is_nonempty(x) -> bool: target_dtype = find_common_type([x.dtype for x in to_concat]) to_concat = [_cast_to_common_type(arr, target_dtype) for arr in to_concat] - if isinstance(to_concat[0], ExtensionArray): + if isinstance(to_concat[0], ExtensionArray) and axis == 0: cls = type(to_concat[0]) return cls._concat_same_type(to_concat) else: - return np.concatenate(to_concat) + to_concat = [ + arr.to_numpy() if isinstance(arr, ExtensionArray) else arr + for arr in to_concat + ] + return np.concatenate(to_concat, axis=axis) elif _contains_datetime or "timedelta" in typs: return _concat_datetime(to_concat, axis=axis, typs=typs) diff --git a/pandas/tests/extension/base/__init__.py b/pandas/tests/extension/base/__init__.py index 323cb843b2d74..97a2763f5ab7b 100644 --- a/pandas/tests/extension/base/__init__.py +++ b/pandas/tests/extension/base/__init__.py @@ -50,6 +50,7 @@ class TestMyDtype(BaseDtypeTests): from .io import BaseParsingTests # noqa from .methods import BaseMethodsTests # noqa from .missing import BaseMissingTests # noqa +from .numpy_array_functions import BaseNumpyArrayFunctionTests # noqa from .ops import ( # noqa BaseArithmeticOpsTests, BaseComparisonOpsTests, diff --git a/pandas/tests/extension/base/base.py b/pandas/tests/extension/base/base.py index 97d8e7c66dbdb..67e7c1809da7a 100644 --- a/pandas/tests/extension/base/base.py +++ b/pandas/tests/extension/base/base.py @@ -19,3 +19,7 @@ def assert_frame_equal(cls, left, right, *args, **kwargs): @classmethod def assert_extension_array_equal(cls, left, right, *args, **kwargs): return tm.assert_extension_array_equal(left, right, *args, **kwargs) + + @classmethod + def assert_numpy_array_equal(cls, left, right, *args, **kwargs): + return tm.assert_numpy_array_equal(left, right, *args, **kwargs) diff --git a/pandas/tests/extension/base/numpy_array_functions.py b/pandas/tests/extension/base/numpy_array_functions.py new file mode 100644 index 0000000000000..f1d38b246bf95 --- /dev/null +++ b/pandas/tests/extension/base/numpy_array_functions.py @@ -0,0 +1,32 @@ +import numpy as np +import pytest + +import pandas.util._test_decorators as td + +import pandas as pd + +from .base import BaseExtensionTests + + +@td.skip_if_no_nep18 +class BaseNumpyArrayFunctionTests(BaseExtensionTests): + def test_concatenate(self, data): + expected = pd.array(list(data) * 3, dtype=data.dtype) + + result = np.concatenate([data] * 3) + self.assert_extension_array_equal(result, expected) + + result = np.concatenate([data] * 3, axis=0) + self.assert_extension_array_equal(result, expected) + + msg = "axis 1 is out of bounds for array of dimension 1" + with pytest.raises(np.AxisError, match=msg): + np.concatenate([data] * 3, axis=1) + + def test_ndim(self, data): + assert np.ndim(data) == 1 + + def test_vstack(self, data): + expected = np.array([data.to_numpy()] * 2) + result = np.vstack([data, data]) + self.assert_numpy_array_equal(result, expected) diff --git a/pandas/tests/extension/test_boolean.py b/pandas/tests/extension/test_boolean.py index 8acbeaf0b8170..638aed59637fc 100644 --- a/pandas/tests/extension/test_boolean.py +++ b/pandas/tests/extension/test_boolean.py @@ -387,3 +387,7 @@ class TestUnaryOps(base.BaseUnaryOpsTests): class TestParsing(base.BaseParsingTests): pass + + +class TestNumpyArrayFunctions(base.BaseNumpyArrayFunctionTests): + pass diff --git a/pandas/tests/extension/test_categorical.py b/pandas/tests/extension/test_categorical.py index f7b572a70073a..b9061aff5cde1 100644 --- a/pandas/tests/extension/test_categorical.py +++ b/pandas/tests/extension/test_categorical.py @@ -268,3 +268,7 @@ def test_not_equal_with_na(self, categories): class TestParsing(base.BaseParsingTests): pass + + +class TestNumpyArrayFunctions(base.BaseNumpyArrayFunctionTests): + pass diff --git a/pandas/tests/extension/test_datetime.py b/pandas/tests/extension/test_datetime.py index e026809f7e611..413adba040048 100644 --- a/pandas/tests/extension/test_datetime.py +++ b/pandas/tests/extension/test_datetime.py @@ -221,3 +221,7 @@ class TestGroupby(BaseDatetimeTests, base.BaseGroupbyTests): class TestPrinting(BaseDatetimeTests, base.BasePrintingTests): pass + + +class TestNumpyArrayFunctions(base.BaseNumpyArrayFunctionTests): + pass diff --git a/pandas/tests/extension/test_integer.py b/pandas/tests/extension/test_integer.py index 725533765ca2c..1c67b573c8fa4 100644 --- a/pandas/tests/extension/test_integer.py +++ b/pandas/tests/extension/test_integer.py @@ -255,3 +255,7 @@ class TestPrinting(base.BasePrintingTests): class TestParsing(base.BaseParsingTests): pass + + +class TestNumpyArrayFunctions(base.BaseNumpyArrayFunctionTests): + pass diff --git a/pandas/tests/extension/test_interval.py b/pandas/tests/extension/test_interval.py index 2411f6cfbd936..db63177d5b56f 100644 --- a/pandas/tests/extension/test_interval.py +++ b/pandas/tests/extension/test_interval.py @@ -164,3 +164,7 @@ def test_EA_types(self, engine, data): expected_msg = r".*must implement _from_sequence_of_strings.*" with pytest.raises(NotImplementedError, match=expected_msg): super().test_EA_types(engine, data) + + +class TestNumpyArrayFunctions(base.BaseNumpyArrayFunctionTests): + pass diff --git a/pandas/tests/extension/test_numpy.py b/pandas/tests/extension/test_numpy.py index bbfaacae1b444..1246729de2887 100644 --- a/pandas/tests/extension/test_numpy.py +++ b/pandas/tests/extension/test_numpy.py @@ -502,3 +502,8 @@ def test_setitem_loc_iloc_slice(self, data): @skip_nested class TestParsing(BaseNumPyTests, base.BaseParsingTests): pass + + +@skip_nested +class TestNumpyArrayFunctions(base.BaseNumpyArrayFunctionTests): + pass diff --git a/pandas/tests/extension/test_period.py b/pandas/tests/extension/test_period.py index 817881e00fa99..089075f324f94 100644 --- a/pandas/tests/extension/test_period.py +++ b/pandas/tests/extension/test_period.py @@ -172,3 +172,7 @@ class TestParsing(BasePeriodTests, base.BaseParsingTests): @pytest.mark.parametrize("engine", ["c", "python"]) def test_EA_types(self, engine, data): super().test_EA_types(engine, data) + + +class TestNumpyArrayFunctions(base.BaseNumpyArrayFunctionTests): + pass diff --git a/pandas/tests/extension/test_sparse.py b/pandas/tests/extension/test_sparse.py index d11cfd219a443..b1529fbca6cff 100644 --- a/pandas/tests/extension/test_sparse.py +++ b/pandas/tests/extension/test_sparse.py @@ -425,3 +425,7 @@ def test_EA_types(self, engine, data): expected_msg = r".*must implement _from_sequence_of_strings.*" with pytest.raises(NotImplementedError, match=expected_msg): super().test_EA_types(engine, data) + + +class TestNumpyArrayFunctions(base.BaseNumpyArrayFunctionTests): + pass diff --git a/pandas/tests/extension/test_string.py b/pandas/tests/extension/test_string.py index 27a157d2127f6..11788f71482e1 100644 --- a/pandas/tests/extension/test_string.py +++ b/pandas/tests/extension/test_string.py @@ -121,3 +121,7 @@ class TestPrinting(base.BasePrintingTests): class TestGroupBy(base.BaseGroupbyTests): pass + + +class TestNumpyArrayFunctions(base.BaseNumpyArrayFunctionTests): + pass diff --git a/pandas/util/_test_decorators.py b/pandas/util/_test_decorators.py index 0e8f6b933cd97..56777b580c4d6 100644 --- a/pandas/util/_test_decorators.py +++ b/pandas/util/_test_decorators.py @@ -33,6 +33,7 @@ def test_foo(): from pandas.compat import IS64, is_platform_windows from pandas.compat._optional import import_optional_dependency +from pandas.compat.numpy import IS_NEP18_ACTIVE from pandas.core.computation.expressions import NUMEXPR_INSTALLED, USE_NUMEXPR @@ -197,6 +198,9 @@ def skip_if_no(package: str, min_version: Optional[str] = None): not USE_NUMEXPR, reason=f"numexpr enabled->{USE_NUMEXPR}, installed->{NUMEXPR_INSTALLED}", ) +skip_if_no_nep18 = pytest.mark.skipif( + not IS_NEP18_ACTIVE, reason="NEP-18 support is not available in NumPy" +) # TODO: return type, _pytest.mark.structures.MarkDecorator is not public