Skip to content

ENH: NDArrayBackedExtensionArray.__array_function__ #38068

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions pandas/compat/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,25 @@ def np_array_datetime64_compat(arr, *args, **kwargs):
return np.array(arr, *args, **kwargs)


def _is_nep18_active():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We require NumPy>1.16.0, so this can maybe be simplified to

IS_NEP18_ACTIVE = not os.environ.get("NUMPY_EXPERIMENTAL_ARRAY_FUNCTION", "1") == "0"

See https://numpy.org/neps/nep-0018-array-function-protocol.html#implementation.

# copied from dask.array.utils

class A:
def __array_function__(self, *args, **kwargs):
return True

try:
return np.concatenate([A()])
except ValueError:
return False


IS_NEP18_ACTIVE = _is_nep18_active()

__all__ = [
"np",
"_np_version",
"np_version_under1p17",
"is_numpy_dev",
"IS_NEP18_ACTIVE",
]
62 changes: 62 additions & 0 deletions pandas/core/arrays/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,68 @@ def __repr__(self) -> str:
# ------------------------------------------------------------------------
# __array_function__ methods

def __array_function__(self, func, types, args, kwargs):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

array_ufunc now exists in pandas/core/arraylike.py we should share as much as possible here (might mean breaking up the existing code).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the only shared code will be the checking of the types. We've standardized on a pattern like

class Foo:
    _HANDLED_TYPES = (...,)

And then ensuring that the set of types in types is a subset of _HANDLED_TYPES + (type(self),).

for x in types:
if not issubclass(x, (np.ndarray, NDArrayBackedExtensionArray)):
return NotImplemented

if not args:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you give a code example of this path?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

np.delete(arr=np.arange(5), obj=4)

# TODO: if this fails, are we bound for a RecursionError?
for key, value in kwargs.items():
if value is self:
# See if we can treat self as the first arg
import inspect

sig = inspect.signature(func)
params = sig.parameters
first_argname = next(iter(params))
if first_argname == key:
args = (value,)
del kwargs[key]
break
else:
kwargs[key] = np.asarray(self)
break

if args and args[0] is self:

if func in [np.delete, np.repeat, np.atleast_2d]:
res_data = func(self._ndarray, *args[1:], **kwargs)
return self._from_backing_data(res_data)
Comment on lines +333 to +334
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do we know that there's 1 output argument from func here? Do we have a func.nout or something like that to check?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we don't, thats why we're implementing only a specific handful of functions here


# TODO: do we need to convert args to kwargs to ensure nv checks
# are correct?
if func is np.amin:
# error: "NDArrayBackedExtensionArray" has no attribute "min"
return self.min(*args[1:], **kwargs) # type:ignore[attr-defined]
if func is np.amax:
# error: "NDArrayBackedExtensionArray" has no attribute "max"
return self.max(*args[1:], **kwargs) # type:ignore[attr-defined]

if func is np.sum:
# Need to do explicitly otherise np.sum(TimedeltaArray)
# doesnt wrap in Timedelta.
# error: "NDArrayBackedExtensionArray" has no attribute "sum"
return self.sum(*args[1:], **kwargs) # type:ignore[attr-defined]

if func is np.argsort:
if len(args) > 1:
# try to make sure that we are passing kwargs along correclty
raise NotImplementedError
return self.argsort(*args[1:], **kwargs)

if not any(x is self for x in args):
# e.g. np.conatenate we get args[0] is a tuple containing self
largs = list(args)
for i, arg in enumerate(largs):
if isinstance(arg, (list, tuple)):
arg = type(arg)(x if x is not self else np.asarray(x) for x in arg)
largs[i] = arg
args = tuple(largs)

args = [x if x is not self else np.asarray(x) for x in args]
return func(*args, **kwargs)

def putmask(self, mask, value):
"""
Analogue to np.putmask(self, mask, value)
Expand Down
112 changes: 112 additions & 0 deletions pandas/tests/arrays/test_ndarray_backed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
"""
Tests for EA subclasses subclassing NDArrayBackedExtensionArray
"""

import numpy as np
import pytest

from pandas.compat.numpy import IS_NEP18_ACTIVE

from pandas import date_range
import pandas._testing as tm
from pandas.core.arrays import Categorical, PandasArray

pytestmark = pytest.mark.skipif(
not IS_NEP18_ACTIVE,
reason="__array_function__ is not enabled by default until numpy 1.17",
)


class ArrayFunctionTests:
# Tests for subclasses that do not explicitly support 2D yet.
def test_delete_no_axis(self, array):
# with no axis, operates on flattened version
result = np.delete(array, 1)

backing = np.delete(array._ndarray.ravel(), 1)
expected = array._from_backing_data(backing)
tm.assert_equal(result, expected)

def test_repeat(self, array):
result = np.repeat(array, 2)

backing = np.repeat(array._ndarray.ravel(), 2)
expected = array._from_backing_data(backing)
tm.assert_equal(result, expected)


class ArrayFunctionTests2D(ArrayFunctionTests):
@pytest.mark.parametrize("axis", [0, 1])
def test_delete_axis(self, array, axis):
result = np.delete(array, 1, axis=axis)
if axis == 0:
assert result.shape == (array.shape[0] - 1, array.shape[1])
else:
assert result.shape == (array.shape[0], array.shape[1] - 1)

backing = np.delete(array._ndarray, 1, axis=axis)
expected = array._from_backing_data(backing)
tm.assert_equal(result, expected)

# axis as an arg instead of as a kwarg
result = np.delete(array, 1, axis)
tm.assert_equal(result, expected)

@pytest.mark.parametrize("axis", [0, 1])
def test_repeat_axis(self, array, axis):
result = np.repeat(array, 2, axis=axis)

backing = np.repeat(array._ndarray, 2, axis=axis)
expected = array._from_backing_data(backing)
tm.assert_equal(result, expected)

# axis as an arg instead of a kwarg
result = np.repeat(array, 2, axis)
tm.assert_equal(result, expected)

def test_atleast_2d(self, array):
result = np.atleast_2d(array)

assert result.ndim >= 2

if array.ndim == 1:
assert result.shape == (1, array.size)
else:
assert result.shape == array.shape


class TestDatetimeArray(ArrayFunctionTests2D):
@pytest.fixture(params=[1, 2])
def array(self):
dti = date_range("1994-05-12", periods=12, tz="US/Pacific")
dta = dti._data.reshape(3, 4)
return dta


class TestTimedeltaArray(ArrayFunctionTests2D):
@pytest.fixture
def array(self):
dti = date_range("1994-05-12", periods=12, tz="US/Pacific")
dta = dti._data.reshape(3, 4)
return dta - dta[0, 0]


class TestPeriodArray(ArrayFunctionTests2D):
@pytest.fixture
def array(self):
dti = date_range("1994-05-12", periods=12)
pa = dti._data.to_period("D")
return pa.reshape(3, 4)


class TestPandasArray(ArrayFunctionTests):
@pytest.fixture
def array(self):
return PandasArray(np.arange(12))


class TestCategorical(ArrayFunctionTests):
@pytest.fixture
def array(self):
dti = date_range("1994-05-12", periods=12, tz="US/Pacific")
return Categorical(dti)