Skip to content

Commit c19a4ad

Browse files
authored
ENH: Support mask in GroupBy.cumsum (#48070)
1 parent 157a65e commit c19a4ad

File tree

6 files changed

+88
-30
lines changed

6 files changed

+88
-30
lines changed

doc/source/whatsnew/v1.5.0.rst

+4-3
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ Other enhancements
287287
- ``times`` argument in :class:`.ExponentialMovingWindow` now accepts ``np.timedelta64`` (:issue:`47003`)
288288
- :class:`.DataError`, :class:`.SpecificationError`, :class:`.SettingWithCopyError`, :class:`.SettingWithCopyWarning`, :class:`.NumExprClobberingError`, :class:`.UndefinedVariableError`, :class:`.IndexingError`, :class:`.PyperclipException`, :class:`.PyperclipWindowsException`, :class:`.CSSWarning`, :class:`.PossibleDataLossError`, :class:`.ClosedFileError`, :class:`.IncompatibilityWarning`, :class:`.AttributeConflictWarning`, :class:`.DatabaseError`, :class:`.PossiblePrecisionLoss`, :class:`.ValueLabelTypeMismatch`, :class:`.InvalidColumnName`, and :class:`.CategoricalConversionWarning` are now exposed in ``pandas.errors`` (:issue:`27656`)
289289
- Added ``check_like`` argument to :func:`testing.assert_series_equal` (:issue:`47247`)
290-
- Add support for :meth:`GroupBy.ohlc` for extension array dtypes (:issue:`37493`)
290+
- Add support for :meth:`.GroupBy.ohlc` for extension array dtypes (:issue:`37493`)
291291
- Allow reading compressed SAS files with :func:`read_sas` (e.g., ``.sas7bdat.gz`` files)
292292
- :func:`pandas.read_html` now supports extracting links from table cells (:issue:`13141`)
293293
- :meth:`DatetimeIndex.astype` now supports casting timezone-naive indexes to ``datetime64[s]``, ``datetime64[ms]``, and ``datetime64[us]``, and timezone-aware indexes to the corresponding ``datetime64[unit, tzname]`` dtypes (:issue:`47579`)
@@ -1077,12 +1077,13 @@ Groupby/resample/rolling
10771077
- Bug when using ``engine="numba"`` would return the same jitted function when modifying ``engine_kwargs`` (:issue:`46086`)
10781078
- Bug in :meth:`.DataFrameGroupBy.transform` failing when ``axis=1`` and ``func`` is ``"first"`` or ``"last"`` (:issue:`45986`)
10791079
- Bug in :meth:`.DataFrameGroupBy.cumsum` with ``skipna=False`` giving incorrect results (:issue:`46216`)
1080-
- Bug in :meth:`GroupBy.sum` with integer dtypes losing precision (:issue:`37493`)
1080+
- Bug in :meth:`.GroupBy.sum` and :meth:`.GroupBy.cumsum` with integer dtypes losing precision (:issue:`37493`)
10811081
- Bug in :meth:`.GroupBy.cumsum` with ``timedelta64[ns]`` dtype failing to recognize ``NaT`` as a null value (:issue:`46216`)
1082+
- Bug in :meth:`.GroupBy.cumsum` with integer dtypes causing overflows when sum was bigger than maximum of dtype (:issue:`37493`)
10821083
- Bug in :meth:`.GroupBy.cummin` and :meth:`.GroupBy.cummax` with nullable dtypes incorrectly altering the original data in place (:issue:`46220`)
10831084
- Bug in :meth:`DataFrame.groupby` raising error when ``None`` is in first level of :class:`MultiIndex` (:issue:`47348`)
10841085
- Bug in :meth:`.GroupBy.cummax` with ``int64`` dtype with leading value being the smallest possible int64 (:issue:`46382`)
1085-
- Bug in :meth:`GroupBy.cumprod` ``NaN`` influences calculation in different columns with ``skipna=False`` (:issue:`48064`)
1086+
- Bug in :meth:`.GroupBy.cumprod` where ``NaN`` influences calculation in different columns with ``skipna=False`` (:issue:`48064`)
10861087
- Bug in :meth:`.GroupBy.max` with empty groups and ``uint64`` dtype incorrectly raising ``RuntimeError`` (:issue:`46408`)
10871088
- Bug in :meth:`.GroupBy.apply` would fail when ``func`` was a string and args or kwargs were supplied (:issue:`46479`)
10881089
- Bug in :meth:`SeriesGroupBy.apply` would incorrectly name its result when there was a unique group (:issue:`46369`)

pandas/_libs/groupby.pyi

+4-2
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,14 @@ def group_cumprod_float64(
2020
skipna: bool = ...,
2121
) -> None: ...
2222
def group_cumsum(
23-
out: np.ndarray, # numeric[:, ::1]
24-
values: np.ndarray, # ndarray[numeric, ndim=2]
23+
out: np.ndarray, # int64float_t[:, ::1]
24+
values: np.ndarray, # ndarray[int64float_t, ndim=2]
2525
labels: np.ndarray, # const int64_t[:]
2626
ngroups: int,
2727
is_datetimelike: bool,
2828
skipna: bool = ...,
29+
mask: np.ndarray | None = ...,
30+
result_mask: np.ndarray | None = ...,
2931
) -> None: ...
3032
def group_shift_indexer(
3133
out: np.ndarray, # int64_t[::1]

pandas/_libs/groupby.pyx

+53-19
Original file line numberDiff line numberDiff line change
@@ -206,15 +206,24 @@ def group_cumprod_float64(
206206
accum[lab, j] = NaN
207207

208208

209+
ctypedef fused int64float_t:
210+
int64_t
211+
uint64_t
212+
float32_t
213+
float64_t
214+
215+
209216
@cython.boundscheck(False)
210217
@cython.wraparound(False)
211218
def group_cumsum(
212-
numeric_t[:, ::1] out,
213-
ndarray[numeric_t, ndim=2] values,
219+
int64float_t[:, ::1] out,
220+
ndarray[int64float_t, ndim=2] values,
214221
const intp_t[::1] labels,
215222
int ngroups,
216223
bint is_datetimelike,
217224
bint skipna=True,
225+
const uint8_t[:, :] mask=None,
226+
uint8_t[:, ::1] result_mask=None,
218227
) -> None:
219228
"""
220229
Cumulative sum of columns of `values`, in row groups `labels`.
@@ -233,23 +242,33 @@ def group_cumsum(
233242
True if `values` contains datetime-like entries.
234243
skipna : bool
235244
If true, ignore nans in `values`.
245+
mask: np.ndarray[uint8], optional
246+
Mask of `values`; nonzero entries mark missing values.
247+
result_mask: np.ndarray[uint8], optional
248+
Mask of `out`; set to True where the result is missing.
236249

237250
Notes
238251
-----
239252
This method modifies the `out` parameter, rather than returning an object.
240253
"""
241254
cdef:
242255
Py_ssize_t i, j, N, K, size
243-
numeric_t val, y, t, na_val
244-
numeric_t[:, ::1] accum, compensation
256+
int64float_t val, y, t, na_val
257+
int64float_t[:, ::1] accum, compensation
258+
uint8_t[:, ::1] accum_mask
245259
intp_t lab
246260
bint isna_entry, isna_prev = False
261+
bint uses_mask = mask is not None
247262

248263
N, K = (<object>values).shape
264+
265+
if uses_mask:
266+
accum_mask = np.zeros((ngroups, K), dtype="uint8")
267+
249268
accum = np.zeros((ngroups, K), dtype=np.asarray(values).dtype)
250269
compensation = np.zeros((ngroups, K), dtype=np.asarray(values).dtype)
251270

252-
na_val = _get_na_val(<numeric_t>0, is_datetimelike)
271+
na_val = _get_na_val(<int64float_t>0, is_datetimelike)
253272

254273
with nogil:
255274
for i in range(N):
@@ -260,23 +279,45 @@ def group_cumsum(
260279
for j in range(K):
261280
val = values[i, j]
262281

263-
isna_entry = _treat_as_na(val, is_datetimelike)
282+
if uses_mask:
283+
isna_entry = mask[i, j]
284+
else:
285+
isna_entry = _treat_as_na(val, is_datetimelike)
264286

265287
if not skipna:
266-
isna_prev = _treat_as_na(accum[lab, j], is_datetimelike)
288+
if uses_mask:
289+
isna_prev = accum_mask[lab, j]
290+
else:
291+
isna_prev = _treat_as_na(accum[lab, j], is_datetimelike)
292+
267293
if isna_prev:
268-
out[i, j] = na_val
294+
if uses_mask:
295+
result_mask[i, j] = True
296+
# Be deterministic, out was initialized as empty
297+
out[i, j] = 0
298+
else:
299+
out[i, j] = na_val
269300
continue
270301

271302
if isna_entry:
272-
out[i, j] = na_val
303+
304+
if uses_mask:
305+
result_mask[i, j] = True
306+
# Be deterministic, out was initialized as empty
307+
out[i, j] = 0
308+
else:
309+
out[i, j] = na_val
310+
273311
if not skipna:
274-
accum[lab, j] = na_val
312+
if uses_mask:
313+
accum_mask[lab, j] = True
314+
else:
315+
accum[lab, j] = na_val
275316

276317
else:
277318
# For floats, use Kahan summation to reduce floating-point
278319
# error (https://en.wikipedia.org/wiki/Kahan_summation_algorithm)
279-
if numeric_t == float32_t or numeric_t == float64_t:
320+
if int64float_t == float32_t or int64float_t == float64_t:
280321
y = val - compensation[lab, j]
281322
t = accum[lab, j] + y
282323
compensation[lab, j] = t - accum[lab, j] - y
@@ -834,13 +875,6 @@ def group_mean(
834875
out[i, j] = sumx[i, j] / count
835876

836877

837-
ctypedef fused int64float_t:
838-
float32_t
839-
float64_t
840-
int64_t
841-
uint64_t
842-
843-
844878
@cython.wraparound(False)
845879
@cython.boundscheck(False)
846880
def group_ohlc(
@@ -1070,7 +1104,7 @@ cdef numeric_t _get_na_val(numeric_t val, bint is_datetimelike):
10701104
elif numeric_t is int64_t and is_datetimelike:
10711105
na_val = NPY_NAT
10721106
else:
1073-
# Will not be used, but define to avoid uninitialized warning.
1107+
# Used in case of masks
10741108
na_val = 0
10751109
return na_val
10761110

pandas/core/groupby/ops.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ def __init__(self, kind: str, how: str, has_dropped_na: bool) -> None:
158158
"rank",
159159
"sum",
160160
"ohlc",
161+
"cumsum",
161162
}
162163

163164
_cython_arity = {"ohlc": 4} # OHLC
@@ -226,7 +227,7 @@ def _get_cython_vals(self, values: np.ndarray) -> np.ndarray:
226227
# result may still include NaN, so we have to cast
227228
values = ensure_float64(values)
228229

229-
elif how in ["sum", "ohlc"]:
230+
elif how in ["sum", "ohlc", "cumsum"]:
230231
# Avoid overflow during group op
231232
if values.dtype.kind == "i":
232233
values = ensure_int64(values)

pandas/tests/groupby/test_groupby.py

+22-3
Original file line numberDiff line numberDiff line change
@@ -2846,12 +2846,15 @@ def test_single_element_list_grouping():
28462846
values, _ = next(iter(df.groupby(["a"])))
28472847

28482848

2849-
def test_groupby_sum_avoid_casting_to_float():
2849+
@pytest.mark.parametrize("func", ["sum", "cumsum"])
2850+
def test_groupby_sum_avoid_casting_to_float(func):
28502851
# GH#37493
28512852
val = 922337203685477580
28522853
df = DataFrame({"a": 1, "b": [val]})
2853-
result = df.groupby("a").sum() - val
2854+
result = getattr(df.groupby("a"), func)() - val
28542855
expected = DataFrame({"b": [0]}, index=Index([1], name="a"))
2856+
if func == "cumsum":
2857+
expected = expected.reset_index(drop=True)
28552858
tm.assert_frame_equal(result, expected)
28562859

28572860

@@ -2868,7 +2871,7 @@ def test_groupby_sum_support_mask(any_numeric_ea_dtype):
28682871

28692872

28702873
@pytest.mark.parametrize("val, dtype", [(111, "int"), (222, "uint")])
2871-
def test_groupby_sum_overflow(val, dtype):
2874+
def test_groupby_overflow(val, dtype):
28722875
# GH#37493
28732876
df = DataFrame({"a": 1, "b": [val, val]}, dtype=f"{dtype}8")
28742877
result = df.groupby("a").sum()
@@ -2878,3 +2881,19 @@ def test_groupby_sum_overflow(val, dtype):
28782881
dtype=f"{dtype}64",
28792882
)
28802883
tm.assert_frame_equal(result, expected)
2884+
2885+
result = df.groupby("a").cumsum()
2886+
expected = DataFrame({"b": [val, val * 2]}, dtype=f"{dtype}64")
2887+
tm.assert_frame_equal(result, expected)
2888+
2889+
2890+
@pytest.mark.parametrize("skipna, val", [(True, 3), (False, pd.NA)])
2891+
def test_groupby_cumsum_mask(any_numeric_ea_dtype, skipna, val):
2892+
# GH#37493
2893+
df = DataFrame({"a": 1, "b": [1, pd.NA, 2]}, dtype=any_numeric_ea_dtype)
2894+
result = df.groupby("a").cumsum(skipna=skipna)
2895+
expected = DataFrame(
2896+
{"b": [1, pd.NA, val]},
2897+
dtype=any_numeric_ea_dtype,
2898+
)
2899+
tm.assert_frame_equal(result, expected)

pandas/tests/groupby/test_libgroupby.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -183,9 +183,10 @@ def _check_cython_group_transform_cumulative(pd_op, np_op, dtype):
183183
tm.assert_numpy_array_equal(np_op(data), answer[:, 0], check_dtype=False)
184184

185185

186-
def test_cython_group_transform_cumsum(any_real_numpy_dtype):
186+
@pytest.mark.parametrize("np_dtype", ["int64", "uint64", "float32", "float64"])
187+
def test_cython_group_transform_cumsum(np_dtype):
187188
# see gh-4095
188-
dtype = np.dtype(any_real_numpy_dtype).type
189+
dtype = np.dtype(np_dtype).type
189190
pd_op, np_op = group_cumsum, np.cumsum
190191
_check_cython_group_transform_cumulative(pd_op, np_op, dtype)
191192

0 commit comments

Comments
 (0)