Skip to content

Commit cdc7b4a

Browse files
jbrockmendelrhshadrach
authored andcommitted
ENH: implement ExtensionArray.__array_ufunc__ (pandas-dev#43899)
1 parent f157d4d commit cdc7b4a

File tree

8 files changed

+96
-7
lines changed

8 files changed

+96
-7
lines changed

doc/source/whatsnew/v1.4.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,7 @@ Sparse
522522

523523
ExtensionArray
524524
^^^^^^^^^^^^^^
525-
-
525+
- NumPy ufuncs ``np.abs``, ``np.positive``, ``np.negative`` now correctly preserve dtype when called on ExtensionArrays that implement ``__abs__, __pos__, __neg__``, respectively. In particular this is fixed for :class:`TimedeltaArray` (:issue:`43899`)
526526
-
527527

528528
Styler

pandas/core/arraylike.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,8 @@ def reconstruct(result):
371371
# * len(inputs) > 1 is doable when we know that we have
372372
# aligned blocks / dtypes.
373373
inputs = tuple(np.asarray(x) for x in inputs)
374+
# Note: we can't use default_array_ufunc here bc reindexing means
375+
# that `self` may not be among `inputs`
374376
result = getattr(ufunc, method)(*inputs, **kwargs)
375377
elif self.ndim == 1:
376378
# ufunc(series, ...)
@@ -387,7 +389,7 @@ def reconstruct(result):
387389
else:
388390
# otherwise specific ufunc methods (eg np.<ufunc>.accumulate(..))
389391
# Those can have an axis keyword and thus can't be called block-by-block
390-
result = getattr(ufunc, method)(np.asarray(inputs[0]), **kwargs)
392+
result = default_array_ufunc(inputs[0], ufunc, method, *inputs, **kwargs)
391393

392394
result = reconstruct(result)
393395
return result
@@ -452,3 +454,19 @@ def _assign_where(out, result, where) -> None:
452454
out[:] = result
453455
else:
454456
np.putmask(out, where, result)
457+
458+
459+
def default_array_ufunc(self, ufunc: np.ufunc, method: str, *inputs, **kwargs):
460+
"""
461+
Fallback to the behavior we would get if we did not define __array_ufunc__.
462+
463+
Notes
464+
-----
465+
We are assuming that `self` is among `inputs`.
466+
"""
467+
if not any(x is self for x in inputs):
468+
raise NotImplementedError
469+
470+
new_inputs = [x if x is not self else np.asarray(x) for x in inputs]
471+
472+
return getattr(ufunc, method)(*new_inputs, **kwargs)

pandas/core/arrays/base.py

+15
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
from pandas.core.dtypes.missing import isna
6666

6767
from pandas.core import (
68+
arraylike,
6869
missing,
6970
ops,
7071
)
@@ -1366,6 +1367,20 @@ def _empty(cls, shape: Shape, dtype: ExtensionDtype):
13661367
)
13671368
return result
13681369

1370+
def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs):
1371+
if any(
1372+
isinstance(other, (ABCSeries, ABCIndex, ABCDataFrame)) for other in inputs
1373+
):
1374+
return NotImplemented
1375+
1376+
result = arraylike.maybe_dispatch_ufunc_to_dunder_op(
1377+
self, ufunc, method, *inputs, **kwargs
1378+
)
1379+
if result is not NotImplemented:
1380+
return result
1381+
1382+
return arraylike.default_array_ufunc(self, ufunc, method, *inputs, **kwargs)
1383+
13691384

13701385
class ExtensionOpsMixin:
13711386
"""

pandas/core/arrays/boolean.py

+3
Original file line numberDiff line numberDiff line change
@@ -604,3 +604,6 @@ def _maybe_mask_result(self, result, mask, other, op_name: str):
604604
else:
605605
result[mask] = np.nan
606606
return result
607+
608+
def __abs__(self):
609+
return self.copy()

pandas/tests/arrays/boolean/test_ops.py

+7
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,10 @@ def test_invert(self):
1818
{"A": expected, "B": [False, True, True]}, index=["a", "b", "c"]
1919
)
2020
tm.assert_frame_equal(result, expected)
21+
22+
def test_abs(self):
23+
# matching numpy behavior, abs is the identity function
24+
arr = pd.array([True, False, None], dtype="boolean")
25+
result = abs(arr)
26+
27+
tm.assert_extension_array_equal(result, arr)

pandas/tests/arrays/test_timedeltas.py

+19
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,19 @@ def test_abs(self):
9090
result = abs(arr)
9191
tm.assert_timedelta_array_equal(result, expected)
9292

93+
result2 = np.abs(arr)
94+
tm.assert_timedelta_array_equal(result2, expected)
95+
96+
def test_pos(self):
97+
vals = np.array([-3600 * 10 ** 9, "NaT", 7200 * 10 ** 9], dtype="m8[ns]")
98+
arr = TimedeltaArray(vals)
99+
100+
result = +arr
101+
tm.assert_timedelta_array_equal(result, arr)
102+
103+
result2 = np.positive(arr)
104+
tm.assert_timedelta_array_equal(result2, arr)
105+
93106
def test_neg(self):
94107
vals = np.array([-3600 * 10 ** 9, "NaT", 7200 * 10 ** 9], dtype="m8[ns]")
95108
arr = TimedeltaArray(vals)
@@ -100,6 +113,9 @@ def test_neg(self):
100113
result = -arr
101114
tm.assert_timedelta_array_equal(result, expected)
102115

116+
result2 = np.negative(arr)
117+
tm.assert_timedelta_array_equal(result2, expected)
118+
103119
def test_neg_freq(self):
104120
tdi = pd.timedelta_range("2 Days", periods=4, freq="H")
105121
arr = TimedeltaArray(tdi, freq=tdi.freq)
@@ -108,3 +124,6 @@ def test_neg_freq(self):
108124

109125
result = -arr
110126
tm.assert_timedelta_array_equal(result, expected)
127+
128+
result2 = np.negative(arr)
129+
tm.assert_timedelta_array_equal(result2, expected)

pandas/tests/extension/arrow/test_bool.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,10 @@ def test_view(self, data):
5454
# __setitem__ does not work, so we only have a smoke-test
5555
data.view()
5656

57-
@pytest.mark.xfail(raises=AssertionError, reason="Not implemented yet")
57+
@pytest.mark.xfail(
58+
raises=AttributeError,
59+
reason="__eq__ incorrectly returns bool instead of ndarray[bool]",
60+
)
5861
def test_contains(self, data, data_missing):
5962
super().test_contains(data, data_missing)
6063

pandas/tests/extension/base/ops.py

+28-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import numpy as np
34
import pytest
45

56
import pandas as pd
@@ -128,11 +129,13 @@ class BaseComparisonOpsTests(BaseOpsUtil):
128129
"""Various Series and DataFrame comparison ops methods."""
129130

130131
def _compare_other(self, s, data, op_name, other):
132+
131133
op = self.get_op_from_name(op_name)
132-
if op_name == "__eq__":
133-
assert not op(s, other).all()
134-
elif op_name == "__ne__":
135-
assert op(s, other).all()
134+
if op_name in ["__eq__", "__ne__"]:
135+
# comparison should match point-wise comparisons
136+
result = op(s, other)
137+
expected = s.combine(other, op)
138+
self.assert_series_equal(result, expected)
136139

137140
else:
138141

@@ -182,3 +185,24 @@ def test_invert(self, data):
182185
result = ~s
183186
expected = pd.Series(~data, name="name")
184187
self.assert_series_equal(result, expected)
188+
189+
@pytest.mark.parametrize("ufunc", [np.positive, np.negative, np.abs])
190+
def test_unary_ufunc_dunder_equivalence(self, data, ufunc):
191+
# the dunder __pos__ works if and only if np.positive works,
192+
# same for __neg__/np.negative and __abs__/np.abs
193+
attr = {np.positive: "__pos__", np.negative: "__neg__", np.abs: "__abs__"}[
194+
ufunc
195+
]
196+
197+
exc = None
198+
try:
199+
result = getattr(data, attr)()
200+
except Exception as err:
201+
exc = err
202+
203+
# if __pos__ raised, then so should the ufunc
204+
with pytest.raises((type(exc), TypeError)):
205+
ufunc(data)
206+
else:
207+
alt = ufunc(data)
208+
self.assert_extension_array_equal(result, alt)

0 commit comments

Comments
 (0)