Skip to content

Commit 01f7a86

Browse files
committed
ENH: NDArrayBackedExtensionArray.__array_function__
1 parent 61f67b6 commit 01f7a86

File tree

2 files changed

+147
-0
lines changed

2 files changed

+147
-0
lines changed

pandas/core/arrays/_mixins.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,48 @@ def __repr__(self) -> str:
304304
# ------------------------------------------------------------------------
305305
# __array_function__ methods
306306

307+
def __array_function__(self, func, types, args, kwargs):
308+
309+
if not args:
310+
# I dont think this is possible is it?
311+
raise NotImplementedError
312+
313+
if args[0] is self:
314+
315+
if func in [np.delete, np.repeat, np.atleast_2d]:
316+
res_data = func(self._ndarray, *args[1:], **kwargs)
317+
return self._from_backing_data(res_data)
318+
319+
# TODO: do we need to convert args to kwargs to ensure nv checks
320+
# are correct?
321+
if func is np.amin:
322+
return self.min(*args[1:], **kwargs)
323+
if func is np.amax:
324+
return self.max(*args[1:], **kwargs)
325+
326+
if func is np.sum:
327+
# Need to do explicitly otherise np.sum(TimedeltaArray)
328+
# doesnt wrap in Timedelta.
329+
return self.sum(*args[1:], **kwargs)
330+
331+
if func is np.argsort:
332+
if len(args) > 1:
333+
# try to make sure that we are passing kwargs along correclty
334+
raise NotImplementedError
335+
return self.argsort(*args[1:], **kwargs)
336+
337+
if not any(x is self for x in args):
338+
# e.g. np.conatenate we get args[0] is a tuple containing self
339+
largs = list(args)
340+
for i, arg in enumerate(largs):
341+
if isinstance(arg, (list, tuple)):
342+
arg = type(arg)(x if x is not self else np.asarray(x) for x in arg)
343+
largs[i] = arg
344+
args = tuple(largs)
345+
346+
args = [x if x is not self else np.asarray(x) for x in args]
347+
return func(*args, **kwargs)
348+
307349
def putmask(self, mask, value):
308350
"""
309351
Analogue to np.putmask(self, mask, value)
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
"""
2+
Tests for EA subclasses subclassing NDArrayBackedExtensionArray
3+
"""
4+
5+
import numpy as np
6+
import pytest
7+
8+
from pandas import date_range
9+
import pandas._testing as tm
10+
from pandas.core.arrays import Categorical, PandasArray
11+
12+
13+
class ArrayFunctionTests:
14+
# Tests for subclasses that do not explicitly support 2D yet.
15+
def test_delete_no_axis(self, array):
16+
# with no axis, operates on flattened version
17+
result = np.delete(array, 1)
18+
19+
backing = np.delete(array._ndarray.ravel(), 1)
20+
expected = array._from_backing_data(backing)
21+
tm.assert_equal(result, expected)
22+
23+
def test_repeat(self, array):
24+
result = np.repeat(array, 2)
25+
26+
backing = np.repeat(array._ndarray.ravel(), 2)
27+
expected = array._from_backing_data(backing)
28+
tm.assert_equal(result, expected)
29+
30+
31+
class ArrayFunctionTests2D(ArrayFunctionTests):
32+
@pytest.mark.parametrize("axis", [0, 1])
33+
def test_delete_axis(self, array, axis):
34+
result = np.delete(array, 1, axis=axis)
35+
if axis == 0:
36+
assert result.shape == (array.shape[0] - 1, array.shape[1])
37+
else:
38+
assert result.shape == (array.shape[0], array.shape[1] - 1)
39+
40+
backing = np.delete(array._ndarray, 1, axis=axis)
41+
expected = array._from_backing_data(backing)
42+
tm.assert_equal(result, expected)
43+
44+
# axis as an arg instead of as a kwarg
45+
result = np.delete(array, 1, axis)
46+
tm.assert_equal(result, expected)
47+
48+
@pytest.mark.parametrize("axis", [0, 1])
49+
def test_repeat_axis(self, array, axis):
50+
result = np.repeat(array, 2, axis=axis)
51+
52+
backing = np.repeat(array._ndarray, 2, axis=axis)
53+
expected = array._from_backing_data(backing)
54+
tm.assert_equal(result, expected)
55+
56+
# axis as an arg instead of a kwarg
57+
result = np.repeat(array, 2, axis)
58+
tm.assert_equal(result, expected)
59+
60+
def test_atleast_2d(self, array):
61+
result = np.atleast_2d(array)
62+
63+
assert result.ndim >= 2
64+
65+
if array.ndim == 1:
66+
assert result.shape == (1, array.size)
67+
else:
68+
assert result.shape == array.shape
69+
70+
71+
class TestDatetimeArray(ArrayFunctionTests2D):
72+
@pytest.fixture(params=[1, 2])
73+
def array(self):
74+
dti = date_range("1994-05-12", periods=12, tz="US/Pacific")
75+
dta = dti._data.reshape(3, 4)
76+
return dta
77+
78+
79+
class TestTimedeltaArray(ArrayFunctionTests2D):
80+
@pytest.fixture
81+
def array(self):
82+
dti = date_range("1994-05-12", periods=12, tz="US/Pacific")
83+
dta = dti._data.reshape(3, 4)
84+
return dta - dta[0, 0]
85+
86+
87+
class TestPeriodArray(ArrayFunctionTests2D):
88+
@pytest.fixture
89+
def array(self):
90+
dti = date_range("1994-05-12", periods=12)
91+
pa = dti._data.to_period("D")
92+
return pa.reshape(3, 4)
93+
94+
95+
class TestPandasArray(ArrayFunctionTests):
96+
@pytest.fixture
97+
def array(self):
98+
return PandasArray(np.arange(12))
99+
100+
101+
class TestCategorical(ArrayFunctionTests):
102+
@pytest.fixture
103+
def array(self):
104+
dti = date_range("1994-05-12", periods=12, tz="US/Pacific")
105+
return Categorical(dti)

0 commit comments

Comments
 (0)