Skip to content

Commit a9d8019

Browse files
authored
BUG: retain EA dtypes in DataFrame __pos__, __neg__ (#43883)
1 parent 3494078 commit a9d8019

File tree

7 files changed

+120
-42
lines changed

7 files changed

+120
-42
lines changed

doc/source/whatsnew/v1.4.0.rst

+2-1
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,8 @@ Other enhancements
126126
- Attempting to write into a file in a missing parent directory with :meth:`DataFrame.to_csv`, :meth:`DataFrame.to_html`, :meth:`DataFrame.to_excel`, :meth:`DataFrame.to_feather`, :meth:`DataFrame.to_parquet`, :meth:`DataFrame.to_stata`, :meth:`DataFrame.to_json`, :meth:`DataFrame.to_pickle`, and :meth:`DataFrame.to_xml` now explicitly mentions the missing parent directory; the same is true for :class:`Series` counterparts (:issue:`24306`)
127127
- :meth:`IntegerArray.all`, :meth:`IntegerArray.any`, :meth:`FloatingArray.any`, and :meth:`FloatingArray.all` use Kleene logic (:issue:`41967`)
128128
- Added support for nullable boolean and integer types in :meth:`DataFrame.to_stata`, :class:`~pandas.io.stata.StataWriter`, :class:`~pandas.io.stata.StataWriter117`, and :class:`~pandas.io.stata.StataWriterUTF8` (:issue:`40855`)
129-
-
129+
- :meth:`DataFrame.__pos__`, :meth:`DataFrame.__neg__` now retain ``ExtensionDtype`` dtypes (:issue:`43883`)
130+
130131

131132
.. ---------------------------------------------------------------------------
132133

pandas/_libs/ops_dispatch.pyx

+34-7
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,14 @@ DISPATCHED_UFUNCS = {
1818
"or",
1919
"xor",
2020
"and",
21+
"neg",
22+
"pos",
23+
"abs",
24+
}
25+
UNARY_UFUNCS = {
26+
"neg",
27+
"pos",
28+
"abs",
2129
}
2230
UFUNC_ALIASES = {
2331
"subtract": "sub",
@@ -36,6 +44,9 @@ UFUNC_ALIASES = {
3644
"bitwise_or": "or",
3745
"bitwise_and": "and",
3846
"bitwise_xor": "xor",
47+
"negative": "neg",
48+
"absolute": "abs",
49+
"positive": "pos",
3950
}
4051

4152
# For op(., Array) -> Array.__r{op}__
@@ -80,15 +91,31 @@ def maybe_dispatch_ufunc_to_dunder_op(
8091
def not_implemented(*args, **kwargs):
8192
return NotImplemented
8293

83-
if (method == "__call__"
84-
and op_name in DISPATCHED_UFUNCS
85-
and kwargs.get("out") is None):
86-
if isinstance(inputs[0], type(self)):
94+
if kwargs or ufunc.nin > 2:
95+
return NotImplemented
96+
97+
if method == "__call__" and op_name in DISPATCHED_UFUNCS:
98+
99+
if inputs[0] is self:
87100
name = f"__{op_name}__"
88-
return getattr(self, name, not_implemented)(inputs[1])
89-
else:
101+
meth = getattr(self, name, not_implemented)
102+
103+
if op_name in UNARY_UFUNCS:
104+
assert len(inputs) == 1
105+
return meth()
106+
107+
return meth(inputs[1])
108+
109+
elif inputs[1] is self:
90110
name = REVERSED_NAMES.get(op_name, f"__r{op_name}__")
91-
result = getattr(self, name, not_implemented)(inputs[0])
111+
112+
meth = getattr(self, name, not_implemented)
113+
result = meth(inputs[0])
92114
return result
115+
116+
else:
117+
# should not be reached, but covering our bases
118+
return NotImplemented
119+
93120
else:
94121
return NotImplemented

pandas/core/arrays/numpy_.py

+9
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,15 @@ def to_numpy(
388388
def __invert__(self) -> PandasArray:
389389
return type(self)(~self._ndarray)
390390

391+
def __neg__(self) -> PandasArray:
392+
return type(self)(-self._ndarray)
393+
394+
def __pos__(self) -> PandasArray:
395+
return type(self)(+self._ndarray)
396+
397+
def __abs__(self) -> PandasArray:
398+
return type(self)(abs(self._ndarray))
399+
391400
def _cmp_method(self, other, op):
392401
if isinstance(other, PandasArray):
393402
other = other._ndarray

pandas/core/generic.py

+19-28
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
to_offset,
3737
)
3838
from pandas._typing import (
39+
ArrayLike,
3940
Axis,
4041
CompressionOptions,
4142
Dtype,
@@ -90,7 +91,6 @@
9091
is_list_like,
9192
is_number,
9293
is_numeric_dtype,
93-
is_object_dtype,
9494
is_re_compilable,
9595
is_scalar,
9696
is_timedelta64_dtype,
@@ -1495,36 +1495,27 @@ def equals(self, other: object) -> bool_t:
14951495

14961496
@final
14971497
def __neg__(self):
1498-
values = self._values
1499-
if is_bool_dtype(values):
1500-
arr = operator.inv(values)
1501-
elif (
1502-
is_numeric_dtype(values)
1503-
or is_timedelta64_dtype(values)
1504-
or is_object_dtype(values)
1505-
):
1506-
arr = operator.neg(values)
1507-
else:
1508-
raise TypeError(f"Unary negative expects numeric dtype, not {values.dtype}")
1509-
return self.__array_wrap__(arr)
1498+
def blk_func(values: ArrayLike):
1499+
if is_bool_dtype(values.dtype):
1500+
return operator.inv(values)
1501+
else:
1502+
return operator.neg(values)
1503+
1504+
new_data = self._mgr.apply(blk_func)
1505+
res = self._constructor(new_data)
1506+
return res.__finalize__(self, method="__neg__")
15101507

15111508
@final
15121509
def __pos__(self):
1513-
values = self._values
1514-
if is_bool_dtype(values):
1515-
arr = values
1516-
elif (
1517-
is_numeric_dtype(values)
1518-
or is_timedelta64_dtype(values)
1519-
or is_object_dtype(values)
1520-
):
1521-
arr = operator.pos(values)
1522-
else:
1523-
raise TypeError(
1524-
"Unary plus expects bool, numeric, timedelta, "
1525-
f"or object dtype, not {values.dtype}"
1526-
)
1527-
return self.__array_wrap__(arr)
1510+
def blk_func(values: ArrayLike):
1511+
if is_bool_dtype(values.dtype):
1512+
return values.copy()
1513+
else:
1514+
return operator.pos(values)
1515+
1516+
new_data = self._mgr.apply(blk_func)
1517+
res = self._constructor(new_data)
1518+
return res.__finalize__(self, method="__pos__")
15281519

15291520
@final
15301521
def __invert__(self):

pandas/tests/arithmetic/test_datetime64.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1878,7 +1878,7 @@ def test_datetime64_ops_nat(self):
18781878

18791879
# subtraction
18801880
tm.assert_series_equal(-NaT + datetime_series, nat_series_dtype_timestamp)
1881-
msg = "Unary negative expects"
1881+
msg = "bad operand type for unary -: 'DatetimeArray'"
18821882
with pytest.raises(TypeError, match=msg):
18831883
-single_nat_dtype_datetime + datetime_series
18841884

pandas/tests/arrays/test_numpy.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -198,12 +198,17 @@ def test_validate_reduction_keyword_args():
198198
# Ops
199199

200200

201-
def test_ufunc():
201+
@pytest.mark.parametrize("ufunc", [np.abs, np.negative, np.positive])
202+
def test_ufunc_unary(ufunc):
202203
arr = PandasArray(np.array([-1.0, 0.0, 1.0]))
203-
result = np.abs(arr)
204-
expected = PandasArray(np.abs(arr._ndarray))
204+
result = ufunc(arr)
205+
expected = PandasArray(ufunc(arr._ndarray))
205206
tm.assert_extension_array_equal(result, expected)
206207

208+
209+
def test_ufunc():
210+
arr = PandasArray(np.array([-1.0, 0.0, 1.0]))
211+
207212
r1, r2 = np.divmod(arr, np.add(arr, 2))
208213
e1, e2 = np.divmod(arr._ndarray, np.add(arr._ndarray, 2))
209214
e1 = PandasArray(e1)

pandas/tests/frame/test_unary.py

+47-2
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def test_neg_object(self, df, expected):
4949
def test_neg_raises(self, df):
5050
msg = (
5151
"bad operand type for unary -: 'str'|"
52-
r"Unary negative expects numeric dtype, not datetime64\[ns\]"
52+
r"bad operand type for unary -: 'DatetimeArray'"
5353
)
5454
with pytest.raises(TypeError, match=msg):
5555
(-df)
@@ -116,8 +116,53 @@ def test_pos_object(self, df):
116116
"df", [pd.DataFrame({"a": pd.to_datetime(["2017-01-22", "1970-01-01"])})]
117117
)
118118
def test_pos_raises(self, df):
119-
msg = "Unary plus expects .* dtype, not datetime64\\[ns\\]"
119+
msg = r"bad operand type for unary \+: 'DatetimeArray'"
120120
with pytest.raises(TypeError, match=msg):
121121
(+df)
122122
with pytest.raises(TypeError, match=msg):
123123
(+df["a"])
124+
125+
def test_unary_nullable(self):
126+
df = pd.DataFrame(
127+
{
128+
"a": pd.array([1, -2, 3, pd.NA], dtype="Int64"),
129+
"b": pd.array([4.0, -5.0, 6.0, pd.NA], dtype="Float32"),
130+
"c": pd.array([True, False, False, pd.NA], dtype="boolean"),
131+
# include numpy bool to make sure bool-vs-boolean behavior
132+
# is consistent in non-NA locations
133+
"d": np.array([True, False, False, True]),
134+
}
135+
)
136+
137+
result = +df
138+
res_ufunc = np.positive(df)
139+
expected = df
140+
# TODO: assert that we have copies?
141+
tm.assert_frame_equal(result, expected)
142+
tm.assert_frame_equal(res_ufunc, expected)
143+
144+
result = -df
145+
res_ufunc = np.negative(df)
146+
expected = pd.DataFrame(
147+
{
148+
"a": pd.array([-1, 2, -3, pd.NA], dtype="Int64"),
149+
"b": pd.array([-4.0, 5.0, -6.0, pd.NA], dtype="Float32"),
150+
"c": pd.array([False, True, True, pd.NA], dtype="boolean"),
151+
"d": np.array([False, True, True, False]),
152+
}
153+
)
154+
tm.assert_frame_equal(result, expected)
155+
tm.assert_frame_equal(res_ufunc, expected)
156+
157+
result = abs(df)
158+
res_ufunc = np.abs(df)
159+
expected = pd.DataFrame(
160+
{
161+
"a": pd.array([1, 2, 3, pd.NA], dtype="Int64"),
162+
"b": pd.array([4.0, 5.0, 6.0, pd.NA], dtype="Float32"),
163+
"c": pd.array([True, False, False, pd.NA], dtype="boolean"),
164+
"d": np.array([True, False, False, True]),
165+
}
166+
)
167+
tm.assert_frame_equal(result, expected)
168+
tm.assert_frame_equal(res_ufunc, expected)

0 commit comments

Comments
 (0)