Skip to content

BUG: retain EA dtypes in DataFrame __pos__, __neg__ #43883

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 5, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion doc/source/whatsnew/v1.4.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ Other enhancements
- Attempting to write into a file in missing parent directory with :meth:`DataFrame.to_csv`, :meth:`DataFrame.to_html`, :meth:`DataFrame.to_excel`, :meth:`DataFrame.to_feather`, :meth:`DataFrame.to_parquet`, :meth:`DataFrame.to_stata`, :meth:`DataFrame.to_json`, :meth:`DataFrame.to_pickle`, and :meth:`DataFrame.to_xml` now explicitly mentions missing parent directory, the same is true for :class:`Series` counterparts (:issue:`24306`)
- :meth:`IntegerArray.all` , :meth:`IntegerArray.any`, :meth:`FloatingArray.any`, and :meth:`FloatingArray.all` use Kleene logic (:issue:`41967`)
- Added support for nullable boolean and integer types in :meth:`DataFrame.to_stata`, :class:`~pandas.io.stata.StataWriter`, :class:`~pandas.io.stata.StataWriter117`, and :class:`~pandas.io.stata.StataWriterUTF8` (:issue:`40855`)
-
- :meth:`DataFrame.__pos__`, :meth:`DataFrame.__neg__` now retain ``ExtensionDtype`` dtypes (:issue:`43883`)


.. ---------------------------------------------------------------------------

Expand Down
41 changes: 34 additions & 7 deletions pandas/_libs/ops_dispatch.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@ DISPATCHED_UFUNCS = {
"or",
"xor",
"and",
"neg",
"pos",
"abs",
}
UNARY_UFUNCS = {
"neg",
"pos",
"abs",
}
UFUNC_ALIASES = {
"subtract": "sub",
Expand All @@ -36,6 +44,9 @@ UFUNC_ALIASES = {
"bitwise_or": "or",
"bitwise_and": "and",
"bitwise_xor": "xor",
"negative": "neg",
"absolute": "abs",
"positive": "pos",
}

# For op(., Array) -> Array.__r{op}__
Expand Down Expand Up @@ -80,15 +91,31 @@ def maybe_dispatch_ufunc_to_dunder_op(
def not_implemented(*args, **kwargs):
return NotImplemented

if (method == "__call__"
and op_name in DISPATCHED_UFUNCS
and kwargs.get("out") is None):
if isinstance(inputs[0], type(self)):
if kwargs or ufunc.nin > 2:
return NotImplemented

if method == "__call__" and op_name in DISPATCHED_UFUNCS:

if inputs[0] is self:
name = f"__{op_name}__"
return getattr(self, name, not_implemented)(inputs[1])
else:
meth = getattr(self, name, not_implemented)

if op_name in UNARY_UFUNCS:
assert len(inputs) == 1
return meth()

return meth(inputs[1])

elif inputs[1] is self:
name = REVERSED_NAMES.get(op_name, f"__r{op_name}__")
result = getattr(self, name, not_implemented)(inputs[0])

meth = getattr(self, name, not_implemented)
result = meth(inputs[0])
return result

else:
# should not be reached, but covering our bases
return NotImplemented

else:
return NotImplemented
9 changes: 9 additions & 0 deletions pandas/core/arrays/numpy_.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,15 @@ def to_numpy(
def __invert__(self) -> PandasArray:
return type(self)(~self._ndarray)

def __neg__(self) -> PandasArray:
return type(self)(-self._ndarray)

def __pos__(self) -> PandasArray:
return type(self)(+self._ndarray)

def __abs__(self) -> PandasArray:
return type(self)(abs(self._ndarray))

def _cmp_method(self, other, op):
if isinstance(other, PandasArray):
other = other._ndarray
Expand Down
47 changes: 19 additions & 28 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
to_offset,
)
from pandas._typing import (
ArrayLike,
Axis,
CompressionOptions,
Dtype,
Expand Down Expand Up @@ -90,7 +91,6 @@
is_list_like,
is_number,
is_numeric_dtype,
is_object_dtype,
is_re_compilable,
is_scalar,
is_timedelta64_dtype,
Expand Down Expand Up @@ -1495,36 +1495,27 @@ def equals(self, other: object) -> bool_t:

@final
def __neg__(self):
values = self._values
if is_bool_dtype(values):
arr = operator.inv(values)
elif (
is_numeric_dtype(values)
or is_timedelta64_dtype(values)
or is_object_dtype(values)
):
arr = operator.neg(values)
else:
raise TypeError(f"Unary negative expects numeric dtype, not {values.dtype}")
return self.__array_wrap__(arr)
def blk_func(values: ArrayLike):
if is_bool_dtype(values.dtype):
return operator.inv(values)
else:
return operator.neg(values)

new_data = self._mgr.apply(blk_func)
res = self._constructor(new_data)
return res.__finalize__(self, method="__neg__")

@final
def __pos__(self):
values = self._values
if is_bool_dtype(values):
arr = values
elif (
is_numeric_dtype(values)
or is_timedelta64_dtype(values)
or is_object_dtype(values)
):
arr = operator.pos(values)
else:
raise TypeError(
"Unary plus expects bool, numeric, timedelta, "
f"or object dtype, not {values.dtype}"
)
return self.__array_wrap__(arr)
def blk_func(values: ArrayLike):
if is_bool_dtype(values.dtype):
return values.copy()
else:
return operator.pos(values)

new_data = self._mgr.apply(blk_func)
res = self._constructor(new_data)
return res.__finalize__(self, method="__neg__")

@final
def __invert__(self):
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/arithmetic/test_datetime64.py
Original file line number Diff line number Diff line change
Expand Up @@ -1878,7 +1878,7 @@ def test_datetime64_ops_nat(self):

# subtraction
tm.assert_series_equal(-NaT + datetime_series, nat_series_dtype_timestamp)
msg = "Unary negative expects"
msg = "bad operand type for unary -: 'DatetimeArray'"
with pytest.raises(TypeError, match=msg):
-single_nat_dtype_datetime + datetime_series

Expand Down
11 changes: 8 additions & 3 deletions pandas/tests/arrays/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,12 +198,17 @@ def test_validate_reduction_keyword_args():
# Ops


def test_ufunc():
@pytest.mark.parametrize("ufunc", [np.abs, np.negative, np.positive])
def test_ufunc_unary(ufunc):
arr = PandasArray(np.array([-1.0, 0.0, 1.0]))
result = np.abs(arr)
expected = PandasArray(np.abs(arr._ndarray))
result = ufunc(arr)
expected = PandasArray(ufunc(arr._ndarray))
tm.assert_extension_array_equal(result, expected)


def test_ufunc():
arr = PandasArray(np.array([-1.0, 0.0, 1.0]))

r1, r2 = np.divmod(arr, np.add(arr, 2))
e1, e2 = np.divmod(arr._ndarray, np.add(arr._ndarray, 2))
e1 = PandasArray(e1)
Expand Down
49 changes: 47 additions & 2 deletions pandas/tests/frame/test_unary.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_neg_object(self, df, expected):
def test_neg_raises(self, df):
msg = (
"bad operand type for unary -: 'str'|"
r"Unary negative expects numeric dtype, not datetime64\[ns\]"
r"bad operand type for unary -: 'DatetimeArray'"
)
with pytest.raises(TypeError, match=msg):
(-df)
Expand Down Expand Up @@ -116,8 +116,53 @@ def test_pos_object(self, df):
"df", [pd.DataFrame({"a": pd.to_datetime(["2017-01-22", "1970-01-01"])})]
)
def test_pos_raises(self, df):
msg = "Unary plus expects .* dtype, not datetime64\\[ns\\]"
msg = r"bad operand type for unary \+: 'DatetimeArray'"
with pytest.raises(TypeError, match=msg):
(+df)
with pytest.raises(TypeError, match=msg):
(+df["a"])

def test_unary_nullable(self):
df = pd.DataFrame(
{
"a": pd.array([1, -2, 3, pd.NA], dtype="Int64"),
"b": pd.array([4.0, -5.0, 6.0, pd.NA], dtype="Float32"),
"c": pd.array([True, False, False, pd.NA], dtype="boolean"),
# include numpy bool to make sure bool-vs-boolean behavior
# is consistent in non-NA locations
"d": np.array([True, False, False, True]),
}
)

result = +df
res_ufunc = np.positive(df)
expected = df
# TODO: assert that we have copies?
tm.assert_frame_equal(result, expected)
tm.assert_frame_equal(res_ufunc, expected)

result = -df
res_ufunc = np.negative(df)
expected = pd.DataFrame(
{
"a": pd.array([-1, 2, -3, pd.NA], dtype="Int64"),
"b": pd.array([-4.0, 5.0, -6.0, pd.NA], dtype="Float32"),
"c": pd.array([False, True, True, pd.NA], dtype="boolean"),
"d": np.array([False, True, True, False]),
}
)
tm.assert_frame_equal(result, expected)
tm.assert_frame_equal(res_ufunc, expected)

result = abs(df)
res_ufunc = np.abs(df)
expected = pd.DataFrame(
{
"a": pd.array([1, 2, 3, pd.NA], dtype="Int64"),
"b": pd.array([4.0, 5.0, 6.0, pd.NA], dtype="Float32"),
"c": pd.array([True, False, False, pd.NA], dtype="boolean"),
"d": np.array([True, False, False, True]),
}
)
tm.assert_frame_equal(result, expected)
tm.assert_frame_equal(res_ufunc, expected)