Skip to content

BUG: cumulative functions with ea dtype not handling NA correctly and casting to object #39483

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v1.3.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ ExtensionArray

- Bug in :meth:`DataFrame.where` when ``other`` is a :class:`Series` with :class:`ExtensionArray` dtype (:issue:`38729`)
- Fixed bug where :meth:`Series.idxmax`, :meth:`Series.idxmin` and ``argmax/min`` fail when the underlying data is :class:`ExtensionArray` (:issue:`32749`, :issue:`33719`, :issue:`36566`)
-
- Bug in cumulative functions (``cumsum``, ``cumprod``, ``cummax`` and ``cummin``) with extension dtypes not handling ``NA`` correctly and returning object dtype (:issue:`39479`)

Other
^^^^^
Expand Down
18 changes: 17 additions & 1 deletion pandas/core/nanops.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
is_bool_dtype,
is_complex,
is_datetime64_any_dtype,
is_extension_array_dtype,
is_float,
is_float_dtype,
is_integer,
Expand All @@ -32,7 +33,7 @@
from pandas.core.dtypes.dtypes import PeriodDtype
from pandas.core.dtypes.missing import isna, na_value_for_dtype, notna

from pandas.core.construction import extract_array
from pandas.core.construction import extract_array, sanitize_array

bn = import_optional_dependency("bottleneck", errors="warn")
_BOTTLENECK_INSTALLED = bn is not None
Expand Down Expand Up @@ -1728,6 +1729,21 @@ def na_accum_func(values: ArrayLike, accum_func, *, skipna: bool) -> ArrayLike:
result, dtype=orig_dtype
)

elif is_extension_array_dtype(values.dtype):
if is_integer_dtype(values.dtype) and np.isinf(mask_a):
mask_a = {
np.maximum.accumulate: np.iinfo(values.dtype.type).min,
np.minimum.accumulate: np.iinfo(values.dtype.type).max,
}[accum_func]

vals = values.copy()
mask = isna(vals)
mask_copy = np.copy(mask)
vals[mask] = mask_a
result = accum_func(vals, axis=0)
result[mask_copy] = mask_b
result = sanitize_array(result, None, values.dtype)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sanitize_array is relatively heavy-weight. what cases is it handling?


elif skipna and not issubclass(values.dtype.type, (np.integer, np.bool_)):
vals = values.copy()
mask = isna(vals)
Expand Down
20 changes: 19 additions & 1 deletion pandas/tests/frame/test_cumulative.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
"""

import numpy as np
import pytest

from pandas import DataFrame, Series
from pandas import NA, DataFrame, Series
import pandas._testing as tm


Expand Down Expand Up @@ -133,3 +134,20 @@ def test_cumulative_ops_preserve_dtypes(self):
}
)
tm.assert_frame_equal(result, expected)

@pytest.mark.parametrize(
"func, exp",
[
("cumsum", [2, NA, 7, 6, 6]),
("cumprod", [2, NA, 10, -10, 0]),
("cummin", [2, NA, 2, -1, -1]),
("cummax", [2, NA, 5, 5, 5]),
],
)
@pytest.mark.parametrize("dtype", ["Float64", "Int64"])
def test_cummulative_ops_extension_dtype(self, frame_or_series, dtype, func, exp):
# GH#39479
obj = frame_or_series([2, np.nan, 5, -1, 0], dtype=dtype)
result = getattr(obj, func)()
expected = frame_or_series(exp, dtype=dtype)
tm.assert_equal(result, expected)