Skip to content

ENH: Support mask in GroupBy.cumsum #48070

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Aug 18, 2022
7 changes: 4 additions & 3 deletions doc/source/whatsnew/v1.5.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ Other enhancements
- ``times`` argument in :class:`.ExponentialMovingWindow` now accepts ``np.timedelta64`` (:issue:`47003`)
- :class:`.DataError`, :class:`.SpecificationError`, :class:`.SettingWithCopyError`, :class:`.SettingWithCopyWarning`, :class:`.NumExprClobberingError`, :class:`.UndefinedVariableError`, :class:`.IndexingError`, :class:`.PyperclipException`, :class:`.PyperclipWindowsException`, :class:`.CSSWarning`, :class:`.PossibleDataLossError`, :class:`.ClosedFileError`, :class:`.IncompatibilityWarning`, :class:`.AttributeConflictWarning`, :class:`.DatabaseError`, :class:`.PossiblePrecisionLoss`, :class:`.ValueLabelTypeMismatch`, :class:`.InvalidColumnName`, and :class:`.CategoricalConversionWarning` are now exposed in ``pandas.errors`` (:issue:`27656`)
- Added ``check_like`` argument to :func:`testing.assert_series_equal` (:issue:`47247`)
- Add support for :meth:`GroupBy.ohlc` for extension array dtypes (:issue:`37493`)
- Add support for :meth:`.GroupBy.ohlc` for extension array dtypes (:issue:`37493`)
- Allow reading compressed SAS files with :func:`read_sas` (e.g., ``.sas7bdat.gz`` files)
- :func:`pandas.read_html` now supports extracting links from table cells (:issue:`13141`)
- :meth:`DatetimeIndex.astype` now supports casting timezone-naive indexes to ``datetime64[s]``, ``datetime64[ms]``, and ``datetime64[us]``, and timezone-aware indexes to the corresponding ``datetime64[unit, tzname]`` dtypes (:issue:`47579`)
Expand Down Expand Up @@ -1078,12 +1078,13 @@ Groupby/resample/rolling
- Bug when using ``engine="numba"`` would return the same jitted function when modifying ``engine_kwargs`` (:issue:`46086`)
- Bug in :meth:`.DataFrameGroupBy.transform` fails when ``axis=1`` and ``func`` is ``"first"`` or ``"last"`` (:issue:`45986`)
- Bug in :meth:`DataFrameGroupBy.cumsum` with ``skipna=False`` giving incorrect results (:issue:`46216`)
- Bug in :meth:`GroupBy.sum` with integer dtypes losing precision (:issue:`37493`)
- Bug in :meth:`.GroupBy.sum` and :meth:`.GroupBy.cumsum` with integer dtypes losing precision (:issue:`37493`)
- Bug in :meth:`.GroupBy.cumsum` with ``timedelta64[ns]`` dtype failing to recognize ``NaT`` as a null value (:issue:`46216`)
- Bug in :meth:`.GroupBy.cumsum` with integer dtypes causing overflows when sum was bigger than maximum of dtype (:issue:`37493`)
- Bug in :meth:`.GroupBy.cummin` and :meth:`.GroupBy.cummax` with nullable dtypes incorrectly altering the original data in place (:issue:`46220`)
- Bug in :meth:`DataFrame.groupby` raising error when ``None`` is in first level of :class:`MultiIndex` (:issue:`47348`)
- Bug in :meth:`.GroupBy.cummax` with ``int64`` dtype with leading value being the smallest possible int64 (:issue:`46382`)
- Bug in :meth:`GroupBy.cumprod` ``NaN`` influences calculation in different columns with ``skipna=False`` (:issue:`48064`)
- Bug in :meth:`.GroupBy.cumprod` where ``NaN`` influences calculation in different columns with ``skipna=False`` (:issue:`48064`)
- Bug in :meth:`.GroupBy.max` with empty groups and ``uint64`` dtype incorrectly raising ``RuntimeError`` (:issue:`46408`)
- Bug in :meth:`.GroupBy.apply` would fail when ``func`` was a string and args or kwargs were supplied (:issue:`46479`)
- Bug in :meth:`SeriesGroupBy.apply` would incorrectly name its result when there was a unique group (:issue:`46369`)
Expand Down
6 changes: 4 additions & 2 deletions pandas/_libs/groupby.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@ def group_cumprod_float64(
skipna: bool = ...,
) -> None: ...
def group_cumsum(
out: np.ndarray, # numeric[:, ::1]
values: np.ndarray, # ndarray[numeric, ndim=2]
out: np.ndarray, # int64float_t[:, ::1]
values: np.ndarray, # ndarray[int64float_t, ndim=2]
labels: np.ndarray, # const int64_t[:]
ngroups: int,
is_datetimelike: bool,
skipna: bool = ...,
mask: np.ndarray | None = ...,
result_mask: np.ndarray | None = ...,
) -> None: ...
def group_shift_indexer(
out: np.ndarray, # int64_t[::1]
Expand Down
72 changes: 53 additions & 19 deletions pandas/_libs/groupby.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -206,15 +206,24 @@ def group_cumprod_float64(
accum[lab, j] = NaN


ctypedef fused int64float_t:
int64_t
uint64_t
float32_t
float64_t


@cython.boundscheck(False)
@cython.wraparound(False)
def group_cumsum(
numeric_t[:, ::1] out,
ndarray[numeric_t, ndim=2] values,
int64float_t[:, ::1] out,
ndarray[int64float_t, ndim=2] values,
const intp_t[::1] labels,
int ngroups,
bint is_datetimelike,
bint skipna=True,
const uint8_t[:, :] mask=None,
uint8_t[:, ::1] result_mask=None,
) -> None:
"""
Cumulative sum of columns of `values`, in row groups `labels`.
Expand All @@ -233,23 +242,33 @@ def group_cumsum(
True if `values` contains datetime-like entries.
skipna : bool
If true, ignore nans in `values`.
mask : np.ndarray[uint8], optional
Mask of values
result_mask : np.ndarray[uint8], optional
Mask of out array

Notes
-----
This method modifies the `out` parameter, rather than returning an object.
"""
cdef:
Py_ssize_t i, j, N, K, size
numeric_t val, y, t, na_val
numeric_t[:, ::1] accum, compensation
int64float_t val, y, t, na_val
int64float_t[:, ::1] accum, compensation
uint8_t[:, ::1] accum_mask
intp_t lab
bint isna_entry, isna_prev = False
bint uses_mask = mask is not None

N, K = (<object>values).shape

if uses_mask:
accum_mask = np.zeros((ngroups, K), dtype="uint8")

accum = np.zeros((ngroups, K), dtype=np.asarray(values).dtype)
compensation = np.zeros((ngroups, K), dtype=np.asarray(values).dtype)

na_val = _get_na_val(<numeric_t>0, is_datetimelike)
na_val = _get_na_val(<int64float_t>0, is_datetimelike)

with nogil:
for i in range(N):
Expand All @@ -260,23 +279,45 @@ def group_cumsum(
for j in range(K):
val = values[i, j]

isna_entry = _treat_as_na(val, is_datetimelike)
if uses_mask:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than continue if uses_mask: checks we make that the outermost branch? Might help readability to keep the logic in two different branches rather than continued checks within one

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried that, but imo that reduces readability. The uses_mask is a simple if_else branch, if we move this outside, it is hard to see that the actual logic of the algorithm is the same in both branches (and we have to keep it consistent over time).

isna_entry = mask[i, j]
else:
isna_entry = _treat_as_na(val, is_datetimelike)

if not skipna:
isna_prev = _treat_as_na(accum[lab, j], is_datetimelike)
if uses_mask:
isna_prev = accum_mask[lab, j]
else:
isna_prev = _treat_as_na(accum[lab, j], is_datetimelike)

if isna_prev:
out[i, j] = na_val
if uses_mask:
result_mask[i, j] = True
# Be deterministic, out was initialized as empty
out[i, j] = 0
else:
out[i, j] = na_val
continue

if isna_entry:
out[i, j] = na_val

if uses_mask:
result_mask[i, j] = True
# Be deterministic, out was initialized as empty
out[i, j] = 0
else:
out[i, j] = na_val

if not skipna:
accum[lab, j] = na_val
if uses_mask:
accum_mask[lab, j] = True
else:
accum[lab, j] = na_val

else:
# For floats, use Kahan summation to reduce floating-point
# error (https://en.wikipedia.org/wiki/Kahan_summation_algorithm)
if numeric_t == float32_t or numeric_t == float64_t:
if int64float_t == float32_t or int64float_t == float64_t:
y = val - compensation[lab, j]
t = accum[lab, j] + y
compensation[lab, j] = t - accum[lab, j] - y
Expand Down Expand Up @@ -834,13 +875,6 @@ def group_mean(
out[i, j] = sumx[i, j] / count


ctypedef fused int64float_t:
float32_t
float64_t
int64_t
uint64_t


@cython.wraparound(False)
@cython.boundscheck(False)
def group_ohlc(
Expand Down Expand Up @@ -1070,7 +1104,7 @@ cdef numeric_t _get_na_val(numeric_t val, bint is_datetimelike):
elif numeric_t is int64_t and is_datetimelike:
na_val = NPY_NAT
else:
# Will not be used, but define to avoid uninitialized warning.
# Used in case of masks
na_val = 0
return na_val

Expand Down
3 changes: 2 additions & 1 deletion pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def __init__(self, kind: str, how: str, has_dropped_na: bool) -> None:
"rank",
"sum",
"ohlc",
"cumsum",
}

_cython_arity = {"ohlc": 4} # OHLC
Expand Down Expand Up @@ -226,7 +227,7 @@ def _get_cython_vals(self, values: np.ndarray) -> np.ndarray:
# result may still include NaN, so we have to cast
values = ensure_float64(values)

elif how in ["sum", "ohlc"]:
elif how in ["sum", "ohlc", "cumsum"]:
# Avoid overflow during group op
if values.dtype.kind == "i":
values = ensure_int64(values)
Expand Down
25 changes: 22 additions & 3 deletions pandas/tests/groupby/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -2846,12 +2846,15 @@ def test_single_element_list_grouping():
values, _ = next(iter(df.groupby(["a"])))


def test_groupby_sum_avoid_casting_to_float():
@pytest.mark.parametrize("func", ["sum", "cumsum"])
def test_groupby_sum_avoid_casting_to_float(func):
    # GH#37493: grouped sum/cumsum of large integers must stay integral;
    # a cast to float64 would lose precision near the int64 limit.
    big = 922337203685477580
    grouped = DataFrame({"a": 1, "b": [big]}).groupby("a")
    result = getattr(grouped, func)() - big
    expected = DataFrame({"b": [0]}, index=Index([1], name="a"))
    if func == "cumsum":
        # cumsum is a transform, so it keeps the original row labels
        expected = expected.reset_index(drop=True)
    tm.assert_frame_equal(result, expected)


Expand All @@ -2868,7 +2871,7 @@ def test_groupby_sum_support_mask(any_numeric_ea_dtype):


@pytest.mark.parametrize("val, dtype", [(111, "int"), (222, "uint")])
def test_groupby_sum_overflow(val, dtype):
def test_groupby_overflow(val, dtype):
# GH#37493
df = DataFrame({"a": 1, "b": [val, val]}, dtype=f"{dtype}8")
result = df.groupby("a").sum()
Expand All @@ -2878,3 +2881,19 @@ def test_groupby_sum_overflow(val, dtype):
dtype=f"{dtype}64",
)
tm.assert_frame_equal(result, expected)

result = df.groupby("a").cumsum()
expected = DataFrame({"b": [val, val * 2]}, dtype=f"{dtype}64")
tm.assert_frame_equal(result, expected)


@pytest.mark.parametrize("skipna, val", [(True, 3), (False, pd.NA)])
def test_groupby_cumsum_mask(any_numeric_ea_dtype, skipna, val):
# GH#37493
df = DataFrame({"a": 1, "b": [1, pd.NA, 2]}, dtype=any_numeric_ea_dtype)
result = df.groupby("a").cumsum(skipna=skipna)
expected = DataFrame(
{"b": [1, pd.NA, val]},
dtype=any_numeric_ea_dtype,
)
tm.assert_frame_equal(result, expected)
5 changes: 3 additions & 2 deletions pandas/tests/groupby/test_libgroupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,10 @@ def _check_cython_group_transform_cumulative(pd_op, np_op, dtype):
tm.assert_numpy_array_equal(np_op(data), answer[:, 0], check_dtype=False)


def test_cython_group_transform_cumsum(any_real_numpy_dtype):
@pytest.mark.parametrize("np_dtype", ["int64", "uint64", "float32", "float64"])
def test_cython_group_transform_cumsum(np_dtype):
    # see gh-4095
    # group_cumsum is compiled only for these four dtypes (the
    # int64float_t fused type), so parametrize explicitly rather than
    # over every real numpy dtype.
    scalar_type = np.dtype(np_dtype).type
    _check_cython_group_transform_cumulative(group_cumsum, np.cumsum, scalar_type)

Expand Down