Skip to content

Commit 3ca5773

Browse files
Backport PR pandas-dev#48027 on branch 1.5.x (ENH: Support masks in groupby prod) (pandas-dev#48302)
Backport PR pandas-dev#48027: ENH: Support masks in groupby prod

Co-authored-by: Patrick Hoefler <[email protected]>
1 parent 97cf8e2 commit 3ca5773

File tree

5 files changed

+61
-21
lines changed

5 files changed

+61
-21
lines changed

doc/source/whatsnew/v1.5.0.rst

+1-1
Original file line number | Diff line number | Diff line change
@@ -1165,7 +1165,7 @@ Groupby/resample/rolling
11651165
- Bug when using ``engine="numba"`` would return the same jitted function when modifying ``engine_kwargs`` (:issue:`46086`)
11661166
- Bug in :meth:`.DataFrameGroupBy.transform` fails when ``axis=1`` and ``func`` is ``"first"`` or ``"last"`` (:issue:`45986`)
11671167
- Bug in :meth:`DataFrameGroupBy.cumsum` with ``skipna=False`` giving incorrect results (:issue:`46216`)
1168-
- Bug in :meth:`.GroupBy.sum` and :meth:`.GroupBy.cumsum` with integer dtypes losing precision (:issue:`37493`)
1168+
- Bug in :meth:`.GroupBy.sum`, :meth:`.GroupBy.prod` and :meth:`.GroupBy.cumsum` with integer dtypes losing precision (:issue:`37493`)
11691169
- Bug in :meth:`.GroupBy.cumsum` with ``timedelta64[ns]`` dtype failing to recognize ``NaT`` as a null value (:issue:`46216`)
11701170
- Bug in :meth:`.GroupBy.cumsum` with integer dtypes causing overflows when sum was bigger than maximum of dtype (:issue:`37493`)
11711171
- Bug in :meth:`.GroupBy.cummin` and :meth:`.GroupBy.cummax` with nullable dtypes incorrectly altering the original data in place (:issue:`46220`)

pandas/_libs/groupby.pyi

+4-2
Original file line number | Diff line number | Diff line change
@@ -63,10 +63,12 @@ def group_sum(
6363
is_datetimelike: bool = ...,
6464
) -> None: ...
6565
def group_prod(
66-
out: np.ndarray, # floating[:, ::1]
66+
out: np.ndarray, # int64float_t[:, ::1]
6767
counts: np.ndarray, # int64_t[::1]
68-
values: np.ndarray, # ndarray[floating, ndim=2]
68+
values: np.ndarray, # ndarray[int64float_t, ndim=2]
6969
labels: np.ndarray, # const intp_t[:]
70+
mask: np.ndarray | None,
71+
result_mask: np.ndarray | None = ...,
7072
min_count: int = ...,
7173
) -> None: ...
7274
def group_var(

pandas/_libs/groupby.pyx

+27-7
Original file line number | Diff line number | Diff line change
@@ -682,21 +682,24 @@ def group_sum(
682682
@cython.wraparound(False)
683683
@cython.boundscheck(False)
684684
def group_prod(
685-
floating[:, ::1] out,
685+
int64float_t[:, ::1] out,
686686
int64_t[::1] counts,
687-
ndarray[floating, ndim=2] values,
687+
ndarray[int64float_t, ndim=2] values,
688688
const intp_t[::1] labels,
689+
const uint8_t[:, ::1] mask,
690+
uint8_t[:, ::1] result_mask=None,
689691
Py_ssize_t min_count=0,
690692
) -> None:
691693
"""
692694
Only aggregates on axis=0
693695
"""
694696
cdef:
695697
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
696-
floating val, count
697-
floating[:, ::1] prodx
698+
int64float_t val, count
699+
int64float_t[:, ::1] prodx
698700
int64_t[:, ::1] nobs
699701
Py_ssize_t len_values = len(values), len_labels = len(labels)
702+
bint isna_entry, uses_mask = mask is not None
700703

701704
if len_values != len_labels:
702705
raise ValueError("len(index) != len(labels)")
@@ -716,15 +719,32 @@ def group_prod(
716719
for j in range(K):
717720
val = values[i, j]
718721

719-
# not nan
720-
if val == val:
722+
if uses_mask:
723+
isna_entry = mask[i, j]
724+
elif int64float_t is float32_t or int64float_t is float64_t:
725+
isna_entry = not val == val
726+
else:
727+
isna_entry = False
728+
729+
if not isna_entry:
721730
nobs[lab, j] += 1
722731
prodx[lab, j] *= val
723732

724733
for i in range(ncounts):
725734
for j in range(K):
726735
if nobs[i, j] < min_count:
727-
out[i, j] = NAN
736+
737+
# else case is not possible
738+
if uses_mask:
739+
result_mask[i, j] = True
740+
# Be deterministic, out was initialized as empty
741+
out[i, j] = 0
742+
elif int64float_t is float32_t or int64float_t is float64_t:
743+
out[i, j] = NAN
744+
else:
745+
# integer dtype without a mask: this group has nobs < min_count, which gets handled later
746+
pass
747+
728748
else:
729749
out[i, j] = prodx[i, j]
730750

pandas/core/groupby/ops.py

+15-6
Original file line number | Diff line number | Diff line change
@@ -159,6 +159,7 @@ def __init__(self, kind: str, how: str, has_dropped_na: bool) -> None:
159159
"sum",
160160
"ohlc",
161161
"cumsum",
162+
"prod",
162163
}
163164

164165
_cython_arity = {"ohlc": 4} # OHLC
@@ -221,13 +222,13 @@ def _get_cython_vals(self, values: np.ndarray) -> np.ndarray:
221222
values = ensure_float64(values)
222223

223224
elif values.dtype.kind in ["i", "u"]:
224-
if how in ["var", "prod", "mean"] or (
225+
if how in ["var", "mean"] or (
225226
self.kind == "transform" and self.has_dropped_na
226227
):
227228
# result may still include NaN, so we have to cast
228229
values = ensure_float64(values)
229230

230-
elif how in ["sum", "ohlc", "cumsum"]:
231+
elif how in ["sum", "ohlc", "prod", "cumsum"]:
231232
# Avoid overflow during group op
232233
if values.dtype.kind == "i":
233234
values = ensure_int64(values)
@@ -597,8 +598,16 @@ def _call_cython_op(
597598
min_count=min_count,
598599
is_datetimelike=is_datetimelike,
599600
)
600-
elif self.how == "ohlc":
601-
func(result, counts, values, comp_ids, min_count, mask, result_mask)
601+
elif self.how in ["ohlc", "prod"]:
602+
func(
603+
result,
604+
counts,
605+
values,
606+
comp_ids,
607+
min_count=min_count,
608+
mask=mask,
609+
result_mask=result_mask,
610+
)
602611
else:
603612
func(result, counts, values, comp_ids, min_count, **kwargs)
604613
else:
@@ -631,8 +640,8 @@ def _call_cython_op(
631640
# need to have the result set to np.nan, which may require casting,
632641
# see GH#40767
633642
if is_integer_dtype(result.dtype) and not is_datetimelike:
634-
# Neutral value for sum is 0, so don't fill empty groups with nan
635-
cutoff = max(0 if self.how == "sum" else 1, min_count)
643+
# sum and prod keep the integer dtype, so use a cutoff of 0 and don't fill empty groups with nan
644+
cutoff = max(0 if self.how in ["sum", "prod"] else 1, min_count)
636645
empty_groups = counts < cutoff
637646
if empty_groups.any():
638647
if result_mask is not None and self.uses_mask():

pandas/tests/groupby/test_groupby.py

+14-5
Original file line number | Diff line number | Diff line change
@@ -2847,8 +2847,8 @@ def test_single_element_list_grouping():
28472847
values, _ = next(iter(df.groupby(["a"])))
28482848

28492849

2850-
@pytest.mark.parametrize("func", ["sum", "cumsum"])
2851-
def test_groupby_sum_avoid_casting_to_float(func):
2850+
@pytest.mark.parametrize("func", ["sum", "cumsum", "prod"])
2851+
def test_groupby_avoid_casting_to_float(func):
28522852
# GH#37493
28532853
val = 922337203685477580
28542854
df = DataFrame({"a": 1, "b": [val]})
@@ -2859,12 +2859,13 @@ def test_groupby_sum_avoid_casting_to_float(func):
28592859
tm.assert_frame_equal(result, expected)
28602860

28612861

2862-
def test_groupby_sum_support_mask(any_numeric_ea_dtype):
2862+
@pytest.mark.parametrize("func, val", [("sum", 3), ("prod", 2)])
2863+
def test_groupby_sum_support_mask(any_numeric_ea_dtype, func, val):
28632864
# GH#37493
28642865
df = DataFrame({"a": 1, "b": [1, 2, pd.NA]}, dtype=any_numeric_ea_dtype)
2865-
result = df.groupby("a").sum()
2866+
result = getattr(df.groupby("a"), func)()
28662867
expected = DataFrame(
2867-
{"b": [3]},
2868+
{"b": [val]},
28682869
index=Index([1], name="a", dtype=any_numeric_ea_dtype),
28692870
dtype=any_numeric_ea_dtype,
28702871
)
@@ -2887,6 +2888,14 @@ def test_groupby_overflow(val, dtype):
28872888
expected = DataFrame({"b": [val, val * 2]}, dtype=f"{dtype}64")
28882889
tm.assert_frame_equal(result, expected)
28892890

2891+
result = df.groupby("a").prod()
2892+
expected = DataFrame(
2893+
{"b": [val * val]},
2894+
index=Index([1], name="a", dtype=f"{dtype}64"),
2895+
dtype=f"{dtype}64",
2896+
)
2897+
tm.assert_frame_equal(result, expected)
2898+
28902899

28912900
@pytest.mark.parametrize("skipna, val", [(True, 3), (False, pd.NA)])
28922901
def test_groupby_cumsum_mask(any_numeric_ea_dtype, skipna, val):

0 commit comments

Comments
 (0)