Skip to content

Commit 726994e

Browse files
authored
ENH: Support mask in groupby sum (#48018)
* ENH: Support mask in groupby sum
* ENH: Support mask in groupby sum
* Fix mypy
* Refactor if condition
1 parent 4f256e8 commit 726994e

File tree

5 files changed

+76
-11
lines changed

5 files changed

+76
-11
lines changed

doc/source/whatsnew/v1.5.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -1066,6 +1066,7 @@ Groupby/resample/rolling
10661066
- Bug when using ``engine="numba"`` would return the same jitted function when modifying ``engine_kwargs`` (:issue:`46086`)
10671067
- Bug in :meth:`.DataFrameGroupBy.transform` fails when ``axis=1`` and ``func`` is ``"first"`` or ``"last"`` (:issue:`45986`)
10681068
- Bug in :meth:`DataFrameGroupBy.cumsum` with ``skipna=False`` giving incorrect results (:issue:`46216`)
1069+
- Bug in :meth:`.GroupBy.sum` with integer dtypes losing precision (:issue:`37493`)
10691070
- Bug in :meth:`.GroupBy.cumsum` with ``timedelta64[ns]`` dtype failing to recognize ``NaT`` as a null value (:issue:`46216`)
10701071
- Bug in :meth:`.GroupBy.cummin` and :meth:`.GroupBy.cummax` with nullable dtypes incorrectly altering the original data in place (:issue:`46220`)
10711072
- Bug in :meth:`DataFrame.groupby` raising error when ``None`` is in first level of :class:`MultiIndex` (:issue:`47348`)

pandas/_libs/groupby.pyi

+4-2
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,12 @@ def group_any_all(
5151
skipna: bool,
5252
) -> None: ...
5353
def group_sum(
54-
out: np.ndarray, # complexfloating_t[:, ::1]
54+
out: np.ndarray, # complexfloatingintuint_t[:, ::1]
5555
counts: np.ndarray, # int64_t[::1]
56-
values: np.ndarray, # ndarray[complexfloating_t, ndim=2]
56+
values: np.ndarray, # ndarray[complexfloatingintuint_t, ndim=2]
5757
labels: np.ndarray, # const intp_t[:]
58+
mask: np.ndarray | None,
59+
result_mask: np.ndarray | None = ...,
5860
min_count: int = ...,
5961
is_datetimelike: bool = ...,
6062
) -> None: ...

pandas/_libs/groupby.pyx

+44-7
Original file line numberDiff line numberDiff line change
@@ -513,6 +513,15 @@ ctypedef fused mean_t:
513513

514514
ctypedef fused sum_t:
515515
mean_t
516+
int8_t
517+
int16_t
518+
int32_t
519+
int64_t
520+
521+
uint8_t
522+
uint16_t
523+
uint32_t
524+
uint64_t
516525
object
517526

518527

@@ -523,6 +532,8 @@ def group_sum(
523532
int64_t[::1] counts,
524533
ndarray[sum_t, ndim=2] values,
525534
const intp_t[::1] labels,
535+
const uint8_t[:, :] mask,
536+
uint8_t[:, ::1] result_mask=None,
526537
Py_ssize_t min_count=0,
527538
bint is_datetimelike=False,
528539
) -> None:
@@ -535,6 +546,8 @@ def group_sum(
535546
sum_t[:, ::1] sumx, compensation
536547
int64_t[:, ::1] nobs
537548
Py_ssize_t len_values = len(values), len_labels = len(labels)
549+
bint uses_mask = mask is not None
550+
bint isna_entry
538551

539552
if len_values != len_labels:
540553
raise ValueError("len(index) != len(labels)")
@@ -572,7 +585,8 @@ def group_sum(
572585
for i in range(ncounts):
573586
for j in range(K):
574587
if nobs[i, j] < min_count:
575-
out[i, j] = NAN
588+
out[i, j] = None
589+
576590
else:
577591
out[i, j] = sumx[i, j]
578592
else:
@@ -590,11 +604,18 @@ def group_sum(
590604
# With dt64/td64 values, values have been cast to float64
591605
# instead of int64 for group_sum, but the logic
592606
# is otherwise the same as in _treat_as_na
593-
if val == val and not (
594-
sum_t is float64_t
595-
and is_datetimelike
596-
and val == <float64_t>NPY_NAT
597-
):
607+
if uses_mask:
608+
isna_entry = mask[i, j]
609+
elif (sum_t is float32_t or sum_t is float64_t
610+
or sum_t is complex64_t or sum_t is complex128_t):
611+
# avoid warnings because of equality comparison
612+
isna_entry = not val == val
613+
elif sum_t is int64_t and is_datetimelike and val == NPY_NAT:
614+
isna_entry = True
615+
else:
616+
isna_entry = False
617+
618+
if not isna_entry:
598619
nobs[lab, j] += 1
599620
y = val - compensation[lab, j]
600621
t = sumx[lab, j] + y
@@ -604,7 +625,23 @@ def group_sum(
604625
for i in range(ncounts):
605626
for j in range(K):
606627
if nobs[i, j] < min_count:
607-
out[i, j] = NAN
628+
# if we are integer dtype, not is_datetimelike, and
629+
# not uses_mask, then getting here implies that
630+
# counts[i] < min_count, which means we will
631+
# be cast to float64 and masked at the end
632+
# of WrappedCythonOp._call_cython_op. So we can safely
633+
# set a placeholder value in out[i, j].
634+
if uses_mask:
635+
result_mask[i, j] = True
636+
elif (sum_t is float32_t or sum_t is float64_t
637+
or sum_t is complex64_t or sum_t is complex128_t):
638+
out[i, j] = NAN
639+
elif sum_t is int64_t:
640+
out[i, j] = NPY_NAT
641+
else:
642+
# placeholder, see above
643+
out[i, j] = 0
644+
608645
else:
609646
out[i, j] = sumx[i, j]
610647

pandas/core/groupby/ops.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ def __init__(self, kind: str, how: str, has_dropped_na: bool) -> None:
155155
"last",
156156
"first",
157157
"rank",
158+
"sum",
158159
}
159160

160161
_cython_arity = {"ohlc": 4} # OHLC
@@ -217,7 +218,7 @@ def _get_cython_vals(self, values: np.ndarray) -> np.ndarray:
217218
values = ensure_float64(values)
218219

219220
elif values.dtype.kind in ["i", "u"]:
220-
if how in ["sum", "var", "prod", "mean", "ohlc"] or (
221+
if how in ["var", "prod", "mean", "ohlc"] or (
221222
self.kind == "transform" and self.has_dropped_na
222223
):
223224
# result may still include NaN, so we have to cast
@@ -578,6 +579,8 @@ def _call_cython_op(
578579
counts=counts,
579580
values=values,
580581
labels=comp_ids,
582+
mask=mask,
583+
result_mask=result_mask,
581584
min_count=min_count,
582585
is_datetimelike=is_datetimelike,
583586
)
@@ -613,7 +616,8 @@ def _call_cython_op(
613616
# need to have the result set to np.nan, which may require casting,
614617
# see GH#40767
615618
if is_integer_dtype(result.dtype) and not is_datetimelike:
616-
cutoff = max(1, min_count)
619+
# Neutral value for sum is 0, so don't fill empty groups with nan
620+
cutoff = max(0 if self.how == "sum" else 1, min_count)
617621
empty_groups = counts < cutoff
618622
if empty_groups.any():
619623
if result_mask is not None and self.uses_mask():

pandas/tests/groupby/test_groupby.py

+21
Original file line numberDiff line numberDiff line change
@@ -2808,3 +2808,24 @@ def test_single_element_list_grouping():
28082808
)
28092809
with tm.assert_produces_warning(FutureWarning, match=msg):
28102810
values, _ = next(iter(df.groupby(["a"])))
2811+
2812+
2813+
def test_groupby_sum_avoid_casting_to_float():
2814+
# GH#37493
2815+
val = 922337203685477580
2816+
df = DataFrame({"a": 1, "b": [val]})
2817+
result = df.groupby("a").sum() - val
2818+
expected = DataFrame({"b": [0]}, index=Index([1], name="a"))
2819+
tm.assert_frame_equal(result, expected)
2820+
2821+
2822+
def test_groupby_sum_support_mask(any_numeric_ea_dtype):
2823+
# GH#37493
2824+
df = DataFrame({"a": 1, "b": [1, 2, pd.NA]}, dtype=any_numeric_ea_dtype)
2825+
result = df.groupby("a").sum()
2826+
expected = DataFrame(
2827+
{"b": [3]},
2828+
index=Index([1], name="a", dtype=any_numeric_ea_dtype),
2829+
dtype=any_numeric_ea_dtype,
2830+
)
2831+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)