From 01f7a866955b97ec6e50ca4eb44cfd0d078240e5 Mon Sep 17 00:00:00 2001 From: Brock Date: Wed, 25 Nov 2020 10:24:52 -0800 Subject: [PATCH 1/4] ENH: NDArrayBackedExtensionArray.__array_function__ --- pandas/core/arrays/_mixins.py | 42 +++++++++ pandas/tests/arrays/test_ndarray_backed.py | 105 +++++++++++++++++++++ 2 files changed, 147 insertions(+) create mode 100644 pandas/tests/arrays/test_ndarray_backed.py diff --git a/pandas/core/arrays/_mixins.py b/pandas/core/arrays/_mixins.py index 5cc6525dc3c9b..4669a7503d242 100644 --- a/pandas/core/arrays/_mixins.py +++ b/pandas/core/arrays/_mixins.py @@ -304,6 +304,48 @@ def __repr__(self) -> str: # ------------------------------------------------------------------------ # __array_function__ methods + def __array_function__(self, func, types, args, kwargs): + + if not args: + # I dont think this is possible is it? + raise NotImplementedError + + if args[0] is self: + + if func in [np.delete, np.repeat, np.atleast_2d]: + res_data = func(self._ndarray, *args[1:], **kwargs) + return self._from_backing_data(res_data) + + # TODO: do we need to convert args to kwargs to ensure nv checks + # are correct? + if func is np.amin: + return self.min(*args[1:], **kwargs) + if func is np.amax: + return self.max(*args[1:], **kwargs) + + if func is np.sum: + # Need to do explicitly otherise np.sum(TimedeltaArray) + # doesnt wrap in Timedelta. + return self.sum(*args[1:], **kwargs) + + if func is np.argsort: + if len(args) > 1: + # try to make sure that we are passing kwargs along correclty + raise NotImplementedError + return self.argsort(*args[1:], **kwargs) + + if not any(x is self for x in args): + # e.g. 
np.concatenate we get args[0] is a tuple containing self + largs = list(args) + for i, arg in enumerate(largs): + if isinstance(arg, (list, tuple)): + arg = type(arg)(x if x is not self else np.asarray(x) for x in arg) + largs[i] = arg + args = tuple(largs) + + args = [x if x is not self else np.asarray(x) for x in args] + return func(*args, **kwargs) + def putmask(self, mask, value): """ Analogue to np.putmask(self, mask, value) diff --git a/pandas/tests/arrays/test_ndarray_backed.py b/pandas/tests/arrays/test_ndarray_backed.py new file mode 100644 index 0000000000000..db7c2344cedc2 --- /dev/null +++ b/pandas/tests/arrays/test_ndarray_backed.py @@ -0,0 +1,105 @@ +""" +Tests for EA subclasses subclassing NDArrayBackedExtensionArray +""" + +import numpy as np +import pytest + +from pandas import date_range +import pandas._testing as tm +from pandas.core.arrays import Categorical, PandasArray + + +class ArrayFunctionTests: + # Tests for subclasses that do not explicitly support 2D yet. + def test_delete_no_axis(self, array): + # with no axis, operates on flattened version + result = np.delete(array, 1) + + backing = np.delete(array._ndarray.ravel(), 1) + expected = array._from_backing_data(backing) + tm.assert_equal(result, expected) + + def test_repeat(self, array): + result = np.repeat(array, 2) + + backing = np.repeat(array._ndarray.ravel(), 2) + expected = array._from_backing_data(backing) + tm.assert_equal(result, expected) + + +class ArrayFunctionTests2D(ArrayFunctionTests): + @pytest.mark.parametrize("axis", [0, 1]) + def test_delete_axis(self, array, axis): + result = np.delete(array, 1, axis=axis) + if axis == 0: + assert result.shape == (array.shape[0] - 1, array.shape[1]) + else: + assert result.shape == (array.shape[0], array.shape[1] - 1) + + backing = np.delete(array._ndarray, 1, axis=axis) + expected = array._from_backing_data(backing) + tm.assert_equal(result, expected) + + # axis as an arg instead of as a kwarg + result = np.delete(array, 1, axis) + 
tm.assert_equal(result, expected) + + @pytest.mark.parametrize("axis", [0, 1]) + def test_repeat_axis(self, array, axis): + result = np.repeat(array, 2, axis=axis) + + backing = np.repeat(array._ndarray, 2, axis=axis) + expected = array._from_backing_data(backing) + tm.assert_equal(result, expected) + + # axis as an arg instead of a kwarg + result = np.repeat(array, 2, axis) + tm.assert_equal(result, expected) + + def test_atleast_2d(self, array): + result = np.atleast_2d(array) + + assert result.ndim >= 2 + + if array.ndim == 1: + assert result.shape == (1, array.size) + else: + assert result.shape == array.shape + + +class TestDatetimeArray(ArrayFunctionTests2D): + @pytest.fixture(params=[1, 2]) + def array(self): + dti = date_range("1994-05-12", periods=12, tz="US/Pacific") + dta = dti._data.reshape(3, 4) + return dta + + +class TestTimedeltaArray(ArrayFunctionTests2D): + @pytest.fixture + def array(self): + dti = date_range("1994-05-12", periods=12, tz="US/Pacific") + dta = dti._data.reshape(3, 4) + return dta - dta[0, 0] + + +class TestPeriodArray(ArrayFunctionTests2D): + @pytest.fixture + def array(self): + dti = date_range("1994-05-12", periods=12) + pa = dti._data.to_period("D") + return pa.reshape(3, 4) + + +class TestPandasArray(ArrayFunctionTests): + @pytest.fixture + def array(self): + return PandasArray(np.arange(12)) + + +class TestCategorical(ArrayFunctionTests): + @pytest.fixture + def array(self): + dti = date_range("1994-05-12", periods=12, tz="US/Pacific") + return Categorical(dti) From 084b38041b0001ec1d00e772d70f179f3dab9c0d Mon Sep 17 00:00:00 2001 From: Brock Date: Wed, 25 Nov 2020 13:44:46 -0800 Subject: [PATCH 2/4] NEP18 check --- pandas/compat/numpy/__init__.py | 16 ++++++++++++++++ pandas/tests/arrays/test_ndarray_backed.py | 7 +++++++ 2 files changed, 23 insertions(+) diff --git a/pandas/compat/numpy/__init__.py b/pandas/compat/numpy/__init__.py index a2444b7ba5a0d..671ec653d6484 100644 --- a/pandas/compat/numpy/__init__.py +++ 
b/pandas/compat/numpy/__init__.py @@ -62,9 +62,25 @@ def np_array_datetime64_compat(arr, *args, **kwargs): return np.array(arr, *args, **kwargs) +def _is_nep18_active(): + # copied from dask.array.utils + + class A: + def __array_function__(self, *args, **kwargs): + return True + + try: + return np.concatenate([A()]) + except ValueError: + return False + + +IS_NEP18_ACTIVE = _is_nep18_active() + __all__ = [ "np", "_np_version", "np_version_under1p17", "is_numpy_dev", + "IS_NEP18_ACTIVE", ] diff --git a/pandas/tests/arrays/test_ndarray_backed.py b/pandas/tests/arrays/test_ndarray_backed.py index db7c2344cedc2..03c23eb01863c 100644 --- a/pandas/tests/arrays/test_ndarray_backed.py +++ b/pandas/tests/arrays/test_ndarray_backed.py @@ -5,10 +5,17 @@ import numpy as np import pytest +from pandas.compat.numpy import IS_NEP18_ACTIVE + from pandas import date_range import pandas._testing as tm from pandas.core.arrays import Categorical, PandasArray +pytestmark = pytest.mark.skipif( + not IS_NEP18_ACTIVE, + reason="__array_function__ is not enabled by default until numpy 1.17", +) + class ArrayFunctionTests: # Tests for subclasses that do not explicitly support 2D yet. From c967d60114cd9e3b90ed4434123dc5bd906d8af9 Mon Sep 17 00:00:00 2001 From: Brock Date: Wed, 25 Nov 2020 15:32:48 -0800 Subject: [PATCH 3/4] no-args cases --- pandas/core/arrays/_mixins.py | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/pandas/core/arrays/_mixins.py b/pandas/core/arrays/_mixins.py index 4669a7503d242..16f2b3da76699 100644 --- a/pandas/core/arrays/_mixins.py +++ b/pandas/core/arrays/_mixins.py @@ -305,12 +305,29 @@ def __repr__(self) -> str: # __array_function__ methods def __array_function__(self, func, types, args, kwargs): + for x in types: + if not issubclass(x, (np.ndarray, NDArrayBackedExtensionArray)): + return NotImplemented if not args: - # I dont think this is possible is it? 
- raise NotImplementedError - - if args[0] is self: + # TODO: if this fails, are we bound for a RecursionError? + for key, value in kwargs.items(): + if value is self: + # See if we can treat self as the first arg + import inspect + + sig = inspect.signature(func) + params = sig.parameters + first_argname = next(iter(params)) + if first_argname == key: + args = (value,) + del kwargs[key] + break + else: + kwargs[key] = np.asarray(self) + break + + if args and args[0] is self: if func in [np.delete, np.repeat, np.atleast_2d]: res_data = func(self._ndarray, *args[1:], **kwargs) From c08778ffbf8ded6e3f7c063a60f09753e89a2c41 Mon Sep 17 00:00:00 2001 From: Brock Date: Thu, 26 Nov 2020 07:49:23 -0800 Subject: [PATCH 4/4] mypy fixup --- pandas/core/arrays/_mixins.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/pandas/core/arrays/_mixins.py b/pandas/core/arrays/_mixins.py index 16f2b3da76699..f448d08f09234 100644 --- a/pandas/core/arrays/_mixins.py +++ b/pandas/core/arrays/_mixins.py @@ -336,14 +336,17 @@ def __array_function__(self, func, types, args, kwargs): # TODO: do we need to convert args to kwargs to ensure nv checks # are correct? if func is np.amin: - return self.min(*args[1:], **kwargs) + # error: "NDArrayBackedExtensionArray" has no attribute "min" + return self.min(*args[1:], **kwargs) # type:ignore[attr-defined] if func is np.amax: - return self.max(*args[1:], **kwargs) + # error: "NDArrayBackedExtensionArray" has no attribute "max" + return self.max(*args[1:], **kwargs) # type:ignore[attr-defined] if func is np.sum: # Need to do explicitly otherise np.sum(TimedeltaArray) # doesnt wrap in Timedelta. - return self.sum(*args[1:], **kwargs) + # error: "NDArrayBackedExtensionArray" has no attribute "sum" + return self.sum(*args[1:], **kwargs) # type:ignore[attr-defined] if func is np.argsort: if len(args) > 1: