Skip to content

ENH: Support mask in groupby sum #48018

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.5.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1047,6 +1047,7 @@ Groupby/resample/rolling
- Bug where using ``engine="numba"`` would return the same jitted function after modifying ``engine_kwargs`` (:issue:`46086`)
- Bug in :meth:`.DataFrameGroupBy.transform` failing when ``axis=1`` and ``func`` is ``"first"`` or ``"last"`` (:issue:`45986`)
- Bug in :meth:`.DataFrameGroupBy.cumsum` with ``skipna=False`` giving incorrect results (:issue:`46216`)
- Bug in :meth:`.GroupBy.sum` with integer dtypes losing precision (:issue:`37493`)
- Bug in :meth:`.GroupBy.cumsum` with ``timedelta64[ns]`` dtype failing to recognize ``NaT`` as a null value (:issue:`46216`)
- Bug in :meth:`.GroupBy.cummin` and :meth:`.GroupBy.cummax` with nullable dtypes incorrectly altering the original data in place (:issue:`46220`)
- Bug in :meth:`DataFrame.groupby` raising error when ``None`` is in first level of :class:`MultiIndex` (:issue:`47348`)
Expand Down
6 changes: 4 additions & 2 deletions pandas/_libs/groupby.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,12 @@ def group_any_all(
skipna: bool,
) -> None: ...
def group_sum(
out: np.ndarray, # complexfloating_t[:, ::1]
out: np.ndarray, # complexfloatingintuint_t[:, ::1]
counts: np.ndarray, # int64_t[::1]
values: np.ndarray, # ndarray[complexfloating_t, ndim=2]
values: np.ndarray, # ndarray[complexfloatingintuint_t, ndim=2]
labels: np.ndarray, # const intp_t[:]
mask: np.ndarray | None,
result_mask: np.ndarray | None = ...,
min_count: int = ...,
is_datetimelike: bool = ...,
) -> None: ...
Expand Down
51 changes: 44 additions & 7 deletions pandas/_libs/groupby.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,15 @@ ctypedef fused mean_t:

ctypedef fused sum_t:
mean_t
int8_t
int16_t
int32_t
int64_t

uint8_t
uint16_t
uint32_t
uint64_t
object


Expand All @@ -523,6 +532,8 @@ def group_sum(
int64_t[::1] counts,
ndarray[sum_t, ndim=2] values,
const intp_t[::1] labels,
const uint8_t[:, :] mask,
uint8_t[:, ::1] result_mask=None,
Py_ssize_t min_count=0,
bint is_datetimelike=False,
) -> None:
Expand All @@ -535,6 +546,8 @@ def group_sum(
sum_t[:, ::1] sumx, compensation
int64_t[:, ::1] nobs
Py_ssize_t len_values = len(values), len_labels = len(labels)
bint uses_mask = mask is not None
bint isna_entry

if len_values != len_labels:
raise ValueError("len(index) != len(labels)")
Expand Down Expand Up @@ -572,7 +585,8 @@ def group_sum(
for i in range(ncounts):
for j in range(K):
if nobs[i, j] < min_count:
out[i, j] = NAN
out[i, j] = None

else:
out[i, j] = sumx[i, j]
else:
Expand All @@ -590,11 +604,18 @@ def group_sum(
# With dt64/td64 values, values have been cast to float64
# instead of int64 for group_sum, but the logic
# is otherwise the same as in _treat_as_na
if val == val and not (
sum_t is float64_t
and is_datetimelike
and val == <float64_t>NPY_NAT
):
if uses_mask:
isna_entry = mask[i, j]
elif (sum_t is float32_t or sum_t is float64_t
or sum_t is complex64_t or sum_t is complex64_t):
# avoid warnings because of equality comparison
isna_entry = not val == val
elif sum_t is int64_t and is_datetimelike and val == NPY_NAT:
isna_entry = True
else:
isna_entry = False

if not isna_entry:
nobs[lab, j] += 1
y = val - compensation[lab, j]
t = sumx[lab, j] + y
Expand All @@ -604,7 +625,23 @@ def group_sum(
for i in range(ncounts):
for j in range(K):
if nobs[i, j] < min_count:
out[i, j] = NAN
# if we are integer dtype, not is_datetimelike, and
# not uses_mask, then getting here implies that
# counts[i] < min_count, which means we will
# be cast to float64 and masked at the end
# of WrappedCythonOp._call_cython_op. So we can safely
# set a placeholder value in out[i, j].
if uses_mask:
result_mask[i, j] = True
elif (sum_t is float32_t or sum_t is float64_t
or sum_t is complex64_t or sum_t is complex64_t):
out[i, j] = NAN
elif sum_t is int64_t:
out[i, j] = NPY_NAT
else:
# placeholder, see above
out[i, j] = 0

else:
out[i, j] = sumx[i, j]

Expand Down
8 changes: 6 additions & 2 deletions pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def __init__(self, kind: str, how: str, has_dropped_na: bool) -> None:
"last",
"first",
"rank",
"sum",
}

_cython_arity = {"ohlc": 4} # OHLC
Expand Down Expand Up @@ -217,7 +218,7 @@ def _get_cython_vals(self, values: np.ndarray) -> np.ndarray:
values = ensure_float64(values)

elif values.dtype.kind in ["i", "u"]:
if how in ["sum", "var", "prod", "mean", "ohlc"] or (
if how in ["var", "prod", "mean", "ohlc"] or (
self.kind == "transform" and self.has_dropped_na
):
# result may still include NaN, so we have to cast
Expand Down Expand Up @@ -578,6 +579,8 @@ def _call_cython_op(
counts=counts,
values=values,
labels=comp_ids,
mask=mask,
result_mask=result_mask,
min_count=min_count,
is_datetimelike=is_datetimelike,
)
Expand Down Expand Up @@ -613,7 +616,8 @@ def _call_cython_op(
# need to have the result set to np.nan, which may require casting,
# see GH#40767
if is_integer_dtype(result.dtype) and not is_datetimelike:
cutoff = max(1, min_count)
# Neutral value for sum is 0, so don't fill empty groups with nan
cutoff = max(0 if self.how == "sum" else 1, min_count)
empty_groups = counts < cutoff
if empty_groups.any():
if result_mask is not None and self.uses_mask():
Expand Down
21 changes: 21 additions & 0 deletions pandas/tests/groupby/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -2808,3 +2808,24 @@ def test_single_element_list_grouping():
)
with tm.assert_produces_warning(FutureWarning, match=msg):
values, _ = next(iter(df.groupby(["a"])))


def test_groupby_sum_avoid_casting_to_float():
    # GH#37493
    # This integer is not exactly representable as float64, so the
    # difference below is only exactly 0 if the grouped sum stays integer.
    big_int = 922337203685477580
    frame = DataFrame({"a": 1, "b": [big_int]})

    grouped_sum = frame.groupby("a").sum()
    result = grouped_sum - big_int

    group_index = Index([1], name="a")
    expected = DataFrame({"b": [0]}, index=group_index)
    tm.assert_frame_equal(result, expected)


def test_groupby_sum_support_mask(any_numeric_ea_dtype):
# GH#37493
df = DataFrame({"a": 1, "b": [1, 2, pd.NA]}, dtype=any_numeric_ea_dtype)
result = df.groupby("a").sum()
expected = DataFrame(
{"b": [3]},
index=Index([1], name="a", dtype=any_numeric_ea_dtype),
dtype=any_numeric_ea_dtype,
)
tm.assert_frame_equal(result, expected)