
Commit eac1734

BUG: GroupBy.cumsum td64, skipna=False (#46216)
1 parent 87803d0 commit eac1734

3 files changed: +76, -21 lines

doc/source/whatsnew/v1.5.0.rst (+3)
@@ -429,6 +429,9 @@ Groupby/resample/rolling
 - Bug in :meth:`.ExponentialMovingWindow.mean` with ``axis=1`` and ``engine='numba'`` when the :class:`DataFrame` has more columns than rows (:issue:`46086`)
 - Bug when using ``engine="numba"`` would return the same jitted function when modifying ``engine_kwargs`` (:issue:`46086`)
 - Bug in :meth:`.DataFrameGroupby.transform` fails when ``axis=1`` and ``func`` is ``"first"`` or ``"last"`` (:issue:`45986`)
+- Bug in :meth:`DataFrameGroupby.cumsum` with ``skipna=False`` giving incorrect results (:issue:`??`)
+- Bug in :meth:`.GroupBy.cumsum` with ``timedelta64[ns]`` dtype failing to recognize ``NaT`` as a null value (:issue:`??`)
+-
 
 Reshaping
 ^^^^^^^^^
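
For context, the skipna=False regression is easiest to see on a tiny frame. A minimal sketch (my illustration, not part of the commit) of the behavior the fix restores, matching the regression tests added below:

    # With skipna=False, a NaN must poison every later entry of its
    # group's running sum instead of being skipped.
    import numpy as np
    import pandas as pd

    df = pd.DataFrame({"A": [1, 1, 1], "B": [1.0, np.nan, 2.0]})
    gb = df.groupby("A")

    print(gb.cumsum(skipna=True)["B"].tolist())   # [1.0, nan, 3.0]
    print(gb.cumsum(skipna=False)["B"].tolist())  # [1.0, nan, nan]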

pandas/_libs/groupby.pyx (+39, -21)
@@ -43,6 +43,7 @@ from pandas._libs.algos import (
 from pandas._libs.dtypes cimport (
     iu_64_floating_obj_t,
     iu_64_floating_t,
+    numeric_object_t,
     numeric_t,
 )
 from pandas._libs.missing cimport checknull
@@ -211,7 +212,7 @@ def group_cumsum(
     ndarray[numeric_t, ndim=2] values,
     const intp_t[::1] labels,
     int ngroups,
-    is_datetimelike,
+    bint is_datetimelike,
     bint skipna=True,
 ) -> None:
     """
@@ -238,14 +239,23 @@
     """
     cdef:
         Py_ssize_t i, j, N, K, size
-        numeric_t val, y, t
+        numeric_t val, y, t, na_val
         numeric_t[:, ::1] accum, compensation
         intp_t lab
+        bint isna_entry, isna_prev = False
 
     N, K = (<object>values).shape
     accum = np.zeros((ngroups, K), dtype=np.asarray(values).dtype)
     compensation = np.zeros((ngroups, K), dtype=np.asarray(values).dtype)
 
+    if numeric_t == float32_t or numeric_t == float64_t:
+        na_val = NaN
+    elif numeric_t is int64_t and is_datetimelike:
+        na_val = NPY_NAT
+    else:
+        # Will not be used, but define to avoid unitialized warning.
+        na_val = 0
+
     with nogil:
         for i in range(N):
             lab = labels[i]
@@ -255,22 +265,30 @@
             for j in range(K):
                 val = values[i, j]
 
-                # For floats, use Kahan summation to reduce floating-point
-                # error (https://en.wikipedia.org/wiki/Kahan_summation_algorithm)
-                if numeric_t == float32_t or numeric_t == float64_t:
-                    if val == val:
+                isna_entry = _treat_as_na(val, is_datetimelike)
+
+                if not skipna:
+                    isna_prev = _treat_as_na(accum[lab, j], is_datetimelike)
+                    if isna_prev:
+                        out[i, j] = na_val
+                        continue
+
+
+                if isna_entry:
+                    out[i, j] = na_val
+                    if not skipna:
+                        accum[lab, j] = na_val
+
+                else:
+                    # For floats, use Kahan summation to reduce floating-point
+                    # error (https://en.wikipedia.org/wiki/Kahan_summation_algorithm)
+                    if numeric_t == float32_t or numeric_t == float64_t:
                         y = val - compensation[lab, j]
                         t = accum[lab, j] + y
                         compensation[lab, j] = t - accum[lab, j] - y
-                        accum[lab, j] = t
-                        out[i, j] = t
                     else:
-                        out[i, j] = NaN
-                        if not skipna:
-                            accum[lab, j] = NaN
-                            break
-                else:
-                    t = val + accum[lab, j]
+                        t = val + accum[lab, j]
+
                     accum[lab, j] = t
                     out[i, j] = t
 
@@ -962,19 +980,19 @@ def group_quantile(
 # group_nth, group_last, group_rank
 # ----------------------------------------------------------------------
 
-cdef inline bint _treat_as_na(iu_64_floating_obj_t val, bint is_datetimelike) nogil:
-    if iu_64_floating_obj_t is object:
+cdef inline bint _treat_as_na(numeric_object_t val, bint is_datetimelike) nogil:
+    if numeric_object_t is object:
         # Should never be used, but we need to avoid the `val != val` below
         # or else cython will raise about gil acquisition.
         raise NotImplementedError
 
-    elif iu_64_floating_obj_t is int64_t:
+    elif numeric_object_t is int64_t:
         return is_datetimelike and val == NPY_NAT
-    elif iu_64_floating_obj_t is uint64_t:
-        # There is no NA value for uint64
-        return False
-    else:
+    elif numeric_object_t is float32_t or numeric_object_t is float64_t:
         return val != val
+    else:
+        # non-datetimelike integer
+        return False
 
 
 # TODO(cython3): GH#31710 use memorviews once cython 0.30 is released so we can
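
The fused-type branches in _treat_as_na are resolved at Cython compile time, so each specialization keeps only one of the arms above. A rough runtime analogue in plain Python (a hypothetical helper, shown only to spell out the dispatch):

    import numpy as np

    NPY_NAT = np.iinfo(np.int64).min  # the int64 sentinel behind NaT

    def treat_as_na(val, is_datetimelike):
        if isinstance(val, float):
            return val != val  # NaN is the only float NA
        if isinstance(val, (int, np.integer)):
            # the NaT sentinel only counts as NA for datetime-like data;
            # a plain (non-datetimelike) integer is never NA
            return is_datetimelike and val == NPY_NAT
        raise NotImplementedError  # object dtype is rejected, as above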

pandas/tests/groupby/test_groupby.py (+34)
@@ -2468,6 +2468,40 @@ def test_groupby_numerical_stability_cumsum():
     tm.assert_frame_equal(result, expected, check_exact=True)
 
 
+def test_groupby_cumsum_skipna_false():
+    # don't propagate np.nan above the diagonal
+    arr = np.random.randn(5, 5)
+    df = DataFrame(arr)
+    for i in range(5):
+        df.iloc[i, i] = np.nan
+
+    df["A"] = 1
+    gb = df.groupby("A")
+
+    res = gb.cumsum(skipna=False)
+
+    expected = df[[0, 1, 2, 3, 4]].cumsum(skipna=False)
+    tm.assert_frame_equal(res, expected)
+
+
+def test_groupby_cumsum_timedelta64():
+    # don't ignore is_datetimelike in libgroupby.group_cumsum
+    dti = date_range("2016-01-01", periods=5)
+    ser = Series(dti) - dti[0]
+    ser[2] = pd.NaT
+
+    df = DataFrame({"A": 1, "B": ser})
+    gb = df.groupby("A")
+
+    res = gb.cumsum(numeric_only=False, skipna=True)
+    exp = DataFrame({"B": [ser[0], ser[1], pd.NaT, ser[4], ser[4] * 2]})
+    tm.assert_frame_equal(res, exp)
+
+    res = gb.cumsum(numeric_only=False, skipna=False)
+    exp = DataFrame({"B": [ser[0], ser[1], pd.NaT, pd.NaT, pd.NaT]})
+    tm.assert_frame_equal(res, exp)
+
+
 def test_groupby_mean_duplicate_index(rand_series_with_duplicate_datetimeindex):
     dups = rand_series_with_duplicate_datetimeindex
     result = dups.groupby(level=0).mean()
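
Running the timedelta64 test's setup interactively shows the semantics being pinned down (a sketch assuming a pandas build that includes this fix):

    import pandas as pd

    dti = pd.date_range("2016-01-01", periods=5)
    ser = pd.Series(dti) - dti[0]  # 0 days .. 4 days
    ser[2] = pd.NaT

    gb = pd.DataFrame({"A": 1, "B": ser}).groupby("A")

    gb.cumsum(numeric_only=False, skipna=True)["B"].tolist()
    # [Timedelta('0 days'), Timedelta('1 days'), NaT,
    #  Timedelta('4 days'), Timedelta('8 days')]

    gb.cumsum(numeric_only=False, skipna=False)["B"].tolist()
    # [Timedelta('0 days'), Timedelta('1 days'), NaT, NaT, NaT]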
