Commit 5c30d1c

dsaxton authored and simonjayhawkins committed
BUG: Fix min_count issue for groupby.sum (pandas-dev#32914)
* Add test
* Check for null
* Release note
* Update and comment
* Update test
* Hack
* Try a different casting
* No pd
* Only for add
* Undo
* Release note
* Fix
* Space
* maybe_cast_result
* float -> Int
* Less if

Co-authored-by: Simon Hawkins <[email protected]>
1 parent c06a21f commit 5c30d1c

3 files changed, +22 -11 lines changed

doc/source/whatsnew/v1.1.0.rst

+1
@@ -752,6 +752,7 @@ Groupby/resample/rolling
 - Bug in :meth:`DataFrameGroupby.transform` produces incorrect result with transformation functions (:issue:`30918`)
 - Bug in :meth:`GroupBy.count` causes segmentation fault when grouped-by column contains NaNs (:issue:`32841`)
 - Bug in :meth:`DataFrame.groupby` and :meth:`Series.groupby` produces inconsistent type when aggregating Boolean series (:issue:`32894`)
+- Bug in :meth:`DataFrameGroupBy.sum` and :meth:`SeriesGroupBy.sum` where a large negative number would be returned when the number of non-null values was below ``min_count`` for nullable integer dtypes (:issue:`32861`)
 - Bug in :meth:`SeriesGroupBy.quantile` raising on nullable integers (:issue:`33136`)
 - Bug in :meth:`SeriesGroupBy.first`, :meth:`SeriesGroupBy.last`, :meth:`SeriesGroupBy.min`, and :meth:`SeriesGroupBy.max` returning floats when applied to nullable Booleans (:issue:`33071`)
 - Bug in :meth:`DataFrameGroupBy.agg` with dictionary input losing ``ExtensionArray`` dtypes (:issue:`32194`)
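
The behaviour described in the new release note can be sketched as follows. This is a minimal illustration, assuming a pandas version that includes this fix (1.1.0 or later); the frame and the column names `key` and `val` are made up for the example and are not taken from the commit.

```python
import pandas as pd

# Illustrative data (not from the commit): every group holds exactly one row.
df = pd.DataFrame({"key": [0, 1, 2], "val": [10, 20, 30]}, dtype="Int64")

# Each group has fewer than two non-null values, so min_count=2 is not met
# and every group sum should be missing rather than a sentinel integer.
below = df.groupby("key")["val"].sum(min_count=2)

# With min_count=1 every group qualifies and the sums are ordinary integers.
met = df.groupby("key")["val"].sum(min_count=1)
```

Comparing `below` and `met` shows the intended contract: the result keeps the nullable `Int64` dtype in both cases, and only the groups that fall short of `min_count` become missing.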

pandas/core/groupby/ops.py

+4 -11
@@ -18,6 +18,7 @@
 from pandas.errors import AbstractMethodError
 from pandas.util._decorators import cache_readonly
 
+from pandas.core.dtypes.cast import maybe_cast_result
 from pandas.core.dtypes.common import (
     ensure_float64,
     ensure_int64,
@@ -548,17 +549,6 @@ def _cython_operation(
             if mask.any():
                 result = result.astype("float64")
                 result[mask] = np.nan
-            elif (
-                how == "add"
-                and is_integer_dtype(orig_values.dtype)
-                and is_extension_array_dtype(orig_values.dtype)
-            ):
-                # We need this to ensure that Series[Int64Dtype].resample().sum()
-                # remains int64 dtype.
-                # Two options for avoiding this special case
-                # 1. mask-aware ops and avoid casting to float with NaN above
-                # 2. specify the result dtype when calling this method
-                result = result.astype("int64")
 
         if kind == "aggregate" and self._filter_empty_groups and not counts.all():
             assert result.ndim != 2
@@ -582,6 +572,9 @@ def _cython_operation(
         elif is_datetimelike and kind == "aggregate":
             result = result.astype(orig_values.dtype)
 
+        if is_extension_array_dtype(orig_values.dtype):
+            result = maybe_cast_result(result=result, obj=orig_values, how=how)
+
         return result, names
 
     def aggregate(
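
The removed comment mentions mask-aware handling as one way to avoid the special case. The sketch below illustrates that general idea outside of pandas internals: a float result that uses NaN for missing groups is converted to a nullable integer array by building an explicit mask, so NaN is never cast directly to `int64`. The helper name `float_result_to_nullable_int` is hypothetical, and this is not a description of what `maybe_cast_result` does internally.

```python
import numpy as np
import pandas as pd


def float_result_to_nullable_int(result: np.ndarray) -> pd.arrays.IntegerArray:
    """Hypothetical helper: convert a float result with NaN gaps to Int64."""
    mask = np.isnan(result)
    # Fill the gaps with a placeholder before casting so NaN never reaches
    # astype("int64"); the mask records which slots are really missing.
    values = np.where(mask, 0, result).astype("int64")
    return pd.arrays.IntegerArray(values, mask)


# Group sums as a float kernel might return them (made-up values): NaN marks
# a group that fell below min_count.
raw = np.array([3.0, np.nan, 7.0])
converted = float_result_to_nullable_int(raw)
```

Here the placeholder `0` is never visible to the caller because the mask marks that slot as missing; only the unmasked positions carry real sums.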

pandas/tests/groupby/test_function.py

+17
@@ -1014,3 +1014,20 @@ def test_apply_to_nullable_integer_returns_float(values, function):
     result = groups.agg([function])
     expected.columns = MultiIndex.from_tuples([("b", function)])
     tm.assert_frame_equal(result, expected)
+
+
+def test_groupby_sum_below_mincount_nullable_integer():
+    # https://github.com/pandas-dev/pandas/issues/32861
+    df = pd.DataFrame({"a": [0, 1, 2], "b": [0, 1, 2], "c": [0, 1, 2]}, dtype="Int64")
+    grouped = df.groupby("a")
+    idx = pd.Index([0, 1, 2], dtype=object, name="a")
+
+    result = grouped["b"].sum(min_count=2)
+    expected = pd.Series([pd.NA] * 3, dtype="Int64", index=idx, name="b")
+    tm.assert_series_equal(result, expected)
+
+    result = grouped.sum(min_count=2)
+    expected = pd.DataFrame(
+        {"b": [pd.NA] * 3, "c": [pd.NA] * 3}, dtype="Int64", index=idx
+    )
+    tm.assert_frame_equal(result, expected)
