diff --git a/pandas/compat/numpy/__init__.py b/pandas/compat/numpy/__init__.py
index a2444b7ba5a0d..671ec653d6484 100644
--- a/pandas/compat/numpy/__init__.py
+++ b/pandas/compat/numpy/__init__.py
@@ -62,9 +62,32 @@ def np_array_datetime64_compat(arr, *args, **kwargs):
     return np.array(arr, *args, **kwargs)
 
 
+def _is_nep18_active():
+    """
+    Check whether numpy dispatches via ``__array_function__`` (NEP 18).
+
+    Returns whatever ``np.concatenate`` yields for a dummy object that
+    implements ``__array_function__`` (True when the protocol is honored),
+    or False if numpy raises ValueError instead of dispatching.
+    """
+    # copied from dask.array.utils
+
+    class A:
+        def __array_function__(self, *args, **kwargs):
+            return True
+
+    try:
+        return np.concatenate([A()])
+    except ValueError:
+        return False
+
+
+IS_NEP18_ACTIVE = _is_nep18_active()
+
 __all__ = [
     "np",
     "_np_version",
     "np_version_under1p17",
     "is_numpy_dev",
+    "IS_NEP18_ACTIVE",
 ]
diff --git a/pandas/core/arrays/_mixins.py b/pandas/core/arrays/_mixins.py
index 5cc6525dc3c9b..f448d08f09234 100644
--- a/pandas/core/arrays/_mixins.py
+++ b/pandas/core/arrays/_mixins.py
@@ -304,6 +304,79 @@ def __repr__(self) -> str:
     # ------------------------------------------------------------------------
     # __array_function__ methods
 
+    def __array_function__(self, func, types, args, kwargs):
+        """
+        Dispatch numpy functions (NEP 18) called on this array.
+
+        Returns NotImplemented if ``types`` contains anything other than
+        ndarray or NDArrayBackedExtensionArray subclasses.  When ``self``
+        is the first argument, np.delete, np.repeat and np.atleast_2d
+        operate on the backing ndarray and the result is re-wrapped, while
+        np.amin, np.amax, np.sum and np.argsort defer to the matching
+        methods.  Otherwise ``func`` is called with ``self`` replaced by
+        ``np.asarray(self)``.
+        """
+        for x in types:
+            if not issubclass(x, (np.ndarray, NDArrayBackedExtensionArray)):
+                return NotImplemented
+
+        if not args:
+            # TODO: if this fails, are we bound for a RecursionError?
+            for key, value in kwargs.items():
+                if value is self:
+                    # See if we can treat self as the first arg
+                    import inspect
+
+                    sig = inspect.signature(func)
+                    params = sig.parameters
+                    first_argname = next(iter(params))
+                    if first_argname == key:
+                        args = (value,)
+                        del kwargs[key]
+                        break
+                    else:
+                        kwargs[key] = np.asarray(self)
+                        break
+
+        if args and args[0] is self:
+
+            if func in [np.delete, np.repeat, np.atleast_2d]:
+                res_data = func(self._ndarray, *args[1:], **kwargs)
+                return self._from_backing_data(res_data)
+
+            # TODO: do we need to convert args to kwargs to ensure nv checks
+            #  are correct?
+            if func is np.amin:
+                # error: "NDArrayBackedExtensionArray" has no attribute "min"
+                return self.min(*args[1:], **kwargs)  # type:ignore[attr-defined]
+            if func is np.amax:
+                # error: "NDArrayBackedExtensionArray" has no attribute "max"
+                return self.max(*args[1:], **kwargs)  # type:ignore[attr-defined]
+
+            if func is np.sum:
+                # Need to do explicitly otherwise np.sum(TimedeltaArray)
+                #  doesn't wrap in Timedelta.
+                # error: "NDArrayBackedExtensionArray" has no attribute "sum"
+                return self.sum(*args[1:], **kwargs)  # type:ignore[attr-defined]
+
+            if func is np.argsort:
+                if len(args) > 1:
+                    # try to make sure that we are passing kwargs along correctly
+                    raise NotImplementedError
+                return self.argsort(*args[1:], **kwargs)
+
+        if not any(x is self for x in args):
+            # e.g. np.concatenate we get args[0] is a tuple containing self
+            largs = list(args)
+            for i, arg in enumerate(largs):
+                if isinstance(arg, (list, tuple)):
+                    arg = type(arg)(x if x is not self else np.asarray(x) for x in arg)
+                    largs[i] = arg
+            args = tuple(largs)
+
+        args = [x if x is not self else np.asarray(x) for x in args]
+        return func(*args, **kwargs)
+
     def putmask(self, mask, value):
         """
         Analogue to np.putmask(self, mask, value)
diff --git a/pandas/tests/arrays/test_ndarray_backed.py b/pandas/tests/arrays/test_ndarray_backed.py
new file mode 100644
index 0000000000000..03c23eb01863c
--- /dev/null
+++ b/pandas/tests/arrays/test_ndarray_backed.py
@@ -0,0 +1,112 @@
+"""
+Tests for EA subclasses subclassing NDArrayBackedExtensionArray
+"""
+
+import numpy as np
+import pytest
+
+from pandas.compat.numpy import IS_NEP18_ACTIVE
+
+from pandas import date_range
+import pandas._testing as tm
+from pandas.core.arrays import Categorical, PandasArray
+
+pytestmark = pytest.mark.skipif(
+    not IS_NEP18_ACTIVE,
+    reason="__array_function__ is not enabled by default until numpy 1.17",
+)
+
+
+class ArrayFunctionTests:
+    # Tests for subclasses that do not explicitly support 2D yet.
+    def test_delete_no_axis(self, array):
+        # with no axis, operates on flattened version
+        result = np.delete(array, 1)
+
+        backing = np.delete(array._ndarray.ravel(), 1)
+        expected = array._from_backing_data(backing)
+        tm.assert_equal(result, expected)
+
+    def test_repeat(self, array):
+        result = np.repeat(array, 2)
+
+        backing = np.repeat(array._ndarray.ravel(), 2)
+        expected = array._from_backing_data(backing)
+        tm.assert_equal(result, expected)
+
+
+class ArrayFunctionTests2D(ArrayFunctionTests):
+    @pytest.mark.parametrize("axis", [0, 1])
+    def test_delete_axis(self, array, axis):
+        result = np.delete(array, 1, axis=axis)
+        if axis == 0:
+            assert result.shape == (array.shape[0] - 1, array.shape[1])
+        else:
+            assert result.shape == (array.shape[0], array.shape[1] - 1)
+
+        backing = np.delete(array._ndarray, 1, axis=axis)
+        expected = array._from_backing_data(backing)
+        tm.assert_equal(result, expected)
+
+        # axis as an arg instead of as a kwarg
+        result = np.delete(array, 1, axis)
+        tm.assert_equal(result, expected)
+
+    @pytest.mark.parametrize("axis", [0, 1])
+    def test_repeat_axis(self, array, axis):
+        result = np.repeat(array, 2, axis=axis)
+
+        backing = np.repeat(array._ndarray, 2, axis=axis)
+        expected = array._from_backing_data(backing)
+        tm.assert_equal(result, expected)
+
+        # axis as an arg instead of a kwarg
+        result = np.repeat(array, 2, axis)
+        tm.assert_equal(result, expected)
+
+    def test_atleast_2d(self, array):
+        result = np.atleast_2d(array)
+
+        assert result.ndim >= 2
+
+        if array.ndim == 1:
+            assert result.shape == (1, array.size)
+        else:
+            assert result.shape == array.shape
+
+
+class TestDatetimeArray(ArrayFunctionTests2D):
+    @pytest.fixture
+    def array(self):
+        dti = date_range("1994-05-12", periods=12, tz="US/Pacific")
+        dta = dti._data.reshape(3, 4)
+        return dta
+
+
+class TestTimedeltaArray(ArrayFunctionTests2D):
+    @pytest.fixture
+    def array(self):
+        dti = date_range("1994-05-12", periods=12, tz="US/Pacific")
+        dta = dti._data.reshape(3, 4)
+        return dta - dta[0, 0]
+
+
+class TestPeriodArray(ArrayFunctionTests2D):
+    @pytest.fixture
+    def array(self):
+        dti = date_range("1994-05-12", periods=12)
+        pa = dti._data.to_period("D")
+        return pa.reshape(3, 4)
+
+
+class TestPandasArray(ArrayFunctionTests):
+    @pytest.fixture
+    def array(self):
+        return PandasArray(np.arange(12))
+
+
+class TestCategorical(ArrayFunctionTests):
+    @pytest.fixture
+    def array(self):
+        dti = date_range("1994-05-12", periods=12, tz="US/Pacific")
+        return Categorical(dti)