-
-
Notifications
You must be signed in to change notification settings - Fork 18.4k
ENH: NDArrayBackedExtensionArray.__array_function__ #38068
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
01f7a86
21252e5
084b380
c967d60
c08778f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -304,6 +304,68 @@ def __repr__(self) -> str: | |
# ------------------------------------------------------------------------ | ||
# __array_function__ methods | ||
|
||
def __array_function__(self, func, types, args, kwargs): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. array_ufunc now exists in pandas/core/arraylike.py we should share as much as possible here (might mean breaking up the existing code). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the only shared code will be the checking of the types. We've standardized on a pattern like
And then ensuring that the set of types in |
||
for x in types: | ||
if not issubclass(x, (np.ndarray, NDArrayBackedExtensionArray)): | ||
return NotImplemented | ||
|
||
if not args: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you give a code example of this path? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
# TODO: if this fails, are we bound for a RecursionError? | ||
for key, value in kwargs.items(): | ||
if value is self: | ||
# See if we can treat self as the first arg | ||
import inspect | ||
|
||
sig = inspect.signature(func) | ||
params = sig.parameters | ||
first_argname = next(iter(params)) | ||
if first_argname == key: | ||
args = (value,) | ||
del kwargs[key] | ||
break | ||
else: | ||
kwargs[key] = np.asarray(self) | ||
break | ||
|
||
if args and args[0] is self: | ||
|
||
if func in [np.delete, np.repeat, np.atleast_2d]: | ||
res_data = func(self._ndarray, *args[1:], **kwargs) | ||
return self._from_backing_data(res_data) | ||
Comment on lines
+333
to
+334
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How do we know that there's 1 output argument from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we don't, thats why we're implementing only a specific handful of functions here |
||
|
||
# TODO: do we need to convert args to kwargs to ensure nv checks | ||
# are correct? | ||
if func is np.amin: | ||
# error: "NDArrayBackedExtensionArray" has no attribute "min" | ||
return self.min(*args[1:], **kwargs) # type:ignore[attr-defined] | ||
if func is np.amax: | ||
# error: "NDArrayBackedExtensionArray" has no attribute "max" | ||
return self.max(*args[1:], **kwargs) # type:ignore[attr-defined] | ||
|
||
if func is np.sum: | ||
# Need to do explicitly otherise np.sum(TimedeltaArray) | ||
# doesnt wrap in Timedelta. | ||
# error: "NDArrayBackedExtensionArray" has no attribute "sum" | ||
return self.sum(*args[1:], **kwargs) # type:ignore[attr-defined] | ||
|
||
if func is np.argsort: | ||
if len(args) > 1: | ||
# try to make sure that we are passing kwargs along correclty | ||
raise NotImplementedError | ||
return self.argsort(*args[1:], **kwargs) | ||
|
||
if not any(x is self for x in args): | ||
# e.g. np.conatenate we get args[0] is a tuple containing self | ||
largs = list(args) | ||
for i, arg in enumerate(largs): | ||
if isinstance(arg, (list, tuple)): | ||
arg = type(arg)(x if x is not self else np.asarray(x) for x in arg) | ||
largs[i] = arg | ||
args = tuple(largs) | ||
|
||
args = [x if x is not self else np.asarray(x) for x in args] | ||
return func(*args, **kwargs) | ||
|
||
def putmask(self, mask, value): | ||
""" | ||
Analogue to np.putmask(self, mask, value) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
""" | ||
Tests for EA subclasses subclassing NDArrayBackedExtensionArray | ||
""" | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
from pandas.compat.numpy import IS_NEP18_ACTIVE | ||
|
||
from pandas import date_range | ||
import pandas._testing as tm | ||
from pandas.core.arrays import Categorical, PandasArray | ||
|
||
pytestmark = pytest.mark.skipif( | ||
not IS_NEP18_ACTIVE, | ||
reason="__array_function__ is not enabled by default until numpy 1.17", | ||
) | ||
|
||
|
||
class ArrayFunctionTests: | ||
# Tests for subclasses that do not explicitly support 2D yet. | ||
def test_delete_no_axis(self, array): | ||
# with no axis, operates on flattened version | ||
result = np.delete(array, 1) | ||
|
||
backing = np.delete(array._ndarray.ravel(), 1) | ||
expected = array._from_backing_data(backing) | ||
tm.assert_equal(result, expected) | ||
|
||
def test_repeat(self, array): | ||
result = np.repeat(array, 2) | ||
|
||
backing = np.repeat(array._ndarray.ravel(), 2) | ||
expected = array._from_backing_data(backing) | ||
tm.assert_equal(result, expected) | ||
|
||
|
||
class ArrayFunctionTests2D(ArrayFunctionTests): | ||
@pytest.mark.parametrize("axis", [0, 1]) | ||
def test_delete_axis(self, array, axis): | ||
result = np.delete(array, 1, axis=axis) | ||
if axis == 0: | ||
assert result.shape == (array.shape[0] - 1, array.shape[1]) | ||
else: | ||
assert result.shape == (array.shape[0], array.shape[1] - 1) | ||
|
||
backing = np.delete(array._ndarray, 1, axis=axis) | ||
expected = array._from_backing_data(backing) | ||
tm.assert_equal(result, expected) | ||
|
||
# axis as an arg instead of as a kwarg | ||
result = np.delete(array, 1, axis) | ||
tm.assert_equal(result, expected) | ||
|
||
@pytest.mark.parametrize("axis", [0, 1]) | ||
def test_repeat_axis(self, array, axis): | ||
result = np.repeat(array, 2, axis=axis) | ||
|
||
backing = np.repeat(array._ndarray, 2, axis=axis) | ||
expected = array._from_backing_data(backing) | ||
tm.assert_equal(result, expected) | ||
|
||
# axis as an arg instead of a kwarg | ||
result = np.repeat(array, 2, axis) | ||
tm.assert_equal(result, expected) | ||
|
||
def test_atleast_2d(self, array): | ||
result = np.atleast_2d(array) | ||
|
||
assert result.ndim >= 2 | ||
|
||
if array.ndim == 1: | ||
assert result.shape == (1, array.size) | ||
else: | ||
assert result.shape == array.shape | ||
|
||
|
||
class TestDatetimeArray(ArrayFunctionTests2D): | ||
@pytest.fixture(params=[1, 2]) | ||
def array(self): | ||
dti = date_range("1994-05-12", periods=12, tz="US/Pacific") | ||
dta = dti._data.reshape(3, 4) | ||
return dta | ||
|
||
|
||
class TestTimedeltaArray(ArrayFunctionTests2D): | ||
@pytest.fixture | ||
def array(self): | ||
dti = date_range("1994-05-12", periods=12, tz="US/Pacific") | ||
dta = dti._data.reshape(3, 4) | ||
return dta - dta[0, 0] | ||
|
||
|
||
class TestPeriodArray(ArrayFunctionTests2D): | ||
@pytest.fixture | ||
def array(self): | ||
dti = date_range("1994-05-12", periods=12) | ||
pa = dti._data.to_period("D") | ||
return pa.reshape(3, 4) | ||
|
||
|
||
class TestPandasArray(ArrayFunctionTests): | ||
@pytest.fixture | ||
def array(self): | ||
return PandasArray(np.arange(12)) | ||
|
||
|
||
class TestCategorical(ArrayFunctionTests): | ||
@pytest.fixture | ||
def array(self): | ||
dti = date_range("1994-05-12", periods=12, tz="US/Pacific") | ||
return Categorical(dti) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We require NumPy>1.16.0, so this can maybe be simplified to
See https://numpy.org/neps/nep-0018-array-function-protocol.html#implementation.