Skip to content

Commit 4de1d6b

Browse files
committed
pandas-dev#31422 GroupBy.sum() returns 0 for missing categories when grouping by multiple Categoricals. Updates to tests to reflect this expected output
1 parent b65467a commit 4de1d6b

File tree

3 files changed

+41
-34
lines changed

3 files changed

+41
-34
lines changed

pandas/core/groupby/generic.py

+18-7
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,9 @@ def _wrap_series_output(
363363
return result
364364

365365
def _wrap_aggregated_output(
366-
self, output: Mapping[base.OutputKey, Union[Series, np.ndarray]]
366+
self,
367+
output: Mapping[base.OutputKey, Union[Series, np.ndarray]],
368+
fill_value: Scalar = np.NaN,
367369
) -> Union[Series, DataFrame]:
368370
"""
369371
Wraps the output of a SeriesGroupBy aggregation into the expected result.
@@ -385,7 +387,7 @@ def _wrap_aggregated_output(
385387
result = self._wrap_series_output(
386388
output=output, index=self.grouper.result_index
387389
)
388-
return self._reindex_output(result)
390+
return self._reindex_output(result, fill_value)
389391

390392
def _wrap_transformed_output(
391393
self, output: Mapping[base.OutputKey, Union[Series, np.ndarray]]
@@ -415,7 +417,11 @@ def _wrap_transformed_output(
415417
return result
416418

417419
def _wrap_applied_output(
418-
self, keys: Index, values: Optional[List[Any]], not_indexed_same: bool = False
420+
self,
421+
keys: Index,
422+
values: Optional[List[Any]],
423+
not_indexed_same: bool = False,
424+
fill_value: Scalar = np.NaN,
419425
) -> FrameOrSeriesUnion:
420426
"""
421427
Wrap the output of SeriesGroupBy.apply into the expected result.
@@ -465,7 +471,7 @@ def _get_index() -> Index:
465471
result = self.obj._constructor(
466472
data=values, index=_get_index(), name=self._selection_name
467473
)
468-
return self._reindex_output(result)
474+
return self._reindex_output(result, fill_value)
469475

470476
def _aggregate_named(self, func, *args, **kwargs):
471477
result = {}
@@ -1029,7 +1035,10 @@ def _cython_agg_general(
10291035
agg_blocks, agg_items = self._cython_agg_blocks(
10301036
how, alt=alt, numeric_only=numeric_only, min_count=min_count
10311037
)
1032-
return self._wrap_agged_blocks(agg_blocks, items=agg_items)
1038+
fill_value = self._cython_func_fill_values.get(alt, np.NaN)
1039+
return self._wrap_agged_blocks(
1040+
agg_blocks, items=agg_items, fill_value=fill_value
1041+
)
10331042

10341043
def _cython_agg_blocks(
10351044
self, how: str, alt=None, numeric_only: bool = True, min_count: int = -1
@@ -1219,7 +1228,9 @@ def _aggregate_item_by_item(self, func, *args, **kwargs) -> DataFrame:
12191228

12201229
return self.obj._constructor(result, columns=result_columns)
12211230

1222-
def _wrap_applied_output(self, keys, values, not_indexed_same=False):
1231+
def _wrap_applied_output(
1232+
self, keys, values, not_indexed_same=False, fill_value: Scalar = np.NaN
1233+
):
12231234
if len(keys) == 0:
12241235
return self.obj._constructor(index=keys)
12251236

@@ -1380,7 +1391,7 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):
13801391
if not self.as_index:
13811392
self._insert_inaxis_grouper_inplace(result)
13821393

1383-
return self._reindex_output(result)
1394+
return self._reindex_output(result, fill_value)
13841395

13851396
# values are not series or array-like but scalars
13861397
else:

pandas/core/groupby/groupby.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -888,8 +888,12 @@ def _python_apply_general(
888888
"""
889889
keys, values, mutated = self.grouper.apply(f, data, self.axis)
890890

891+
fill_value = self._cython_func_fill_values.get(f, np.NaN)
891892
return self._wrap_applied_output(
892-
keys, values, not_indexed_same=mutated or self.mutated
893+
keys,
894+
values,
895+
not_indexed_same=mutated or self.mutated,
896+
fill_value=fill_value,
893897
)
894898

895899
def _iterate_slices(self) -> Iterable[Series]:
@@ -1010,6 +1014,8 @@ def _agg_general(
10101014
result = self.aggregate(lambda x: npfunc(x, axis=self.axis))
10111015
return result
10121016

1017+
_cython_func_fill_values = {np.sum: 0}
1018+
10131019
def _cython_agg_general(
10141020
self, how: str, alt=None, numeric_only: bool = True, min_count: int = -1
10151021
):
@@ -1045,7 +1051,9 @@ def _cython_agg_general(
10451051
if len(output) == 0:
10461052
raise DataError("No numeric types to aggregate")
10471053

1048-
return self._wrap_aggregated_output(output)
1054+
fill_value = self._cython_func_fill_values.get(alt, np.NaN)
1055+
1056+
return self._wrap_aggregated_output(output, fill_value)
10491057

10501058
def _python_agg_general(
10511059
self, func, *args, engine="cython", engine_kwargs=None, **kwargs

pandas/tests/groupby/test_categorical.py

+13-25
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import pandas._testing as tm
2020

2121

22-
def cartesian_product_for_groupers(result, args, names):
22+
def cartesian_product_for_groupers(result, args, names, fill_value=np.NaN):
2323
""" Reindex to a cartesian production for the groupers,
2424
preserving the nature (Categorical) of each grouper
2525
"""
@@ -33,7 +33,7 @@ def f(a):
3333
return a
3434

3535
index = MultiIndex.from_product(map(f, args), names=names)
36-
return result.reindex(index).sort_index()
36+
return result.reindex(index, fill_value=fill_value).sort_index()
3737

3838

3939
_results_for_groupbys_with_missing_categories = dict(
@@ -309,7 +309,7 @@ def test_observed(observed):
309309
result = gb.sum()
310310
if not observed:
311311
expected = cartesian_product_for_groupers(
312-
expected, [cat1, cat2, ["foo", "bar"]], list("ABC")
312+
expected, [cat1, cat2, ["foo", "bar"]], list("ABC"), fill_value=0
313313
)
314314

315315
tm.assert_frame_equal(result, expected)
@@ -319,7 +319,9 @@ def test_observed(observed):
319319
expected = DataFrame({"values": [1, 2, 3, 4]}, index=exp_index)
320320
result = gb.sum()
321321
if not observed:
322-
expected = cartesian_product_for_groupers(expected, [cat1, cat2], list("AB"))
322+
expected = cartesian_product_for_groupers(
323+
expected, [cat1, cat2], list("AB"), fill_value=0
324+
)
323325

324326
tm.assert_frame_equal(result, expected)
325327

@@ -1188,9 +1190,10 @@ def test_seriesgroupby_observed_false_or_none(df_cat, observed, operation):
11881190
names=["A", "B"],
11891191
).sortlevel()
11901192

1191-
expected = Series(data=[2, 4, np.nan, 1, np.nan, 3], index=index, name="C")
1193+
expected = Series(data=[2, 4, 0, 1, 0, 3], index=index, name="C")
11921194
grouped = df_cat.groupby(["A", "B"], observed=observed)["C"]
11931195
result = getattr(grouped, operation)(sum)
1196+
11941197
tm.assert_series_equal(result, expected)
11951198

11961199

@@ -1340,15 +1343,6 @@ def test_series_groupby_on_2_categoricals_unobserved_zeroes_or_nans(
13401343
)
13411344
request.node.add_marker(mark)
13421345

1343-
if reduction_func == "sum": # GH 31422
1344-
mark = pytest.mark.xfail(
1345-
reason=(
1346-
"sum should return 0 but currently returns NaN. "
1347-
"This is a known bug. See GH 31422."
1348-
)
1349-
)
1350-
request.node.add_marker(mark)
1351-
13521346
df = pd.DataFrame(
13531347
{
13541348
"cat_1": pd.Categorical(list("AABB"), categories=list("ABC")),
@@ -1369,8 +1363,11 @@ def test_series_groupby_on_2_categoricals_unobserved_zeroes_or_nans(
13691363
val = result.loc[idx]
13701364
assert (pd.isna(zero_or_nan) and pd.isna(val)) or (val == zero_or_nan)
13711365

1372-
# If we expect unobserved values to be zero, we also expect the dtype to be int
1373-
if zero_or_nan == 0:
1366+
# If we expect unobserved values to be zero, we also expect the dtype to be int.
1367+
# Except for .sum(). If the observed categories sum to dtype=float (i.e. their
1368+
# sums have decimals), then the zeros for the missing categories should also be
1369+
# floats.
1370+
if zero_or_nan == 0 and reduction_func != "sum":
13741371
assert np.issubdtype(result.dtype, np.integer)
13751372

13761373

@@ -1412,15 +1409,6 @@ def test_dataframe_groupby_on_2_categoricals_when_observed_is_false(
14121409
if reduction_func == "ngroup":
14131410
pytest.skip("ngroup does not return the Categories on the index")
14141411

1415-
if reduction_func == "sum": # GH 31422
1416-
mark = pytest.mark.xfail(
1417-
reason=(
1418-
"sum should return 0 but currently returns NaN. "
1419-
"This is a known bug. See GH 31422."
1420-
)
1421-
)
1422-
request.node.add_marker(mark)
1423-
14241412
df = pd.DataFrame(
14251413
{
14261414
"cat_1": pd.Categorical(list("AABB"), categories=list("ABC")),

0 commit comments

Comments
 (0)