Skip to content

Commit 181f593

Browse files
API: add numeric_only support to groupby agg
1 parent 6f39c4f commit 181f593

File tree

4 files changed

+68
-25
lines changed

4 files changed

+68
-25
lines changed

pandas/core/apply.py

+9-16
Original file line numberDiff line numberDiff line change
@@ -303,13 +303,14 @@ def agg_list_like(self) -> DataFrame | Series:
303303
-------
304304
Result of aggregation.
305305
"""
306-
return self.agg_or_apply_list_like(op_name="agg")
306+
kwargs = self.kwargs
307+
return self.agg_or_apply_list_like(op_name="agg", **kwargs)
307308

308309
def compute_list_like(
309310
self,
310311
op_name: Literal["agg", "apply"],
311312
selected_obj: Series | DataFrame,
312-
kwargs: dict[str, Any],
313+
**kwargs: dict[str, Any],
313314
) -> tuple[list[Hashable] | Index, list[Any]]:
314315
"""
315316
Compute agg/apply results for like-like input.
@@ -333,7 +334,6 @@ def compute_list_like(
333334
"""
334335
func = cast(list[AggFuncTypeBase], self.func)
335336
obj = self.obj
336-
337337
results = []
338338
keys = []
339339

@@ -348,7 +348,6 @@ def compute_list_like(
348348
)
349349
new_res = getattr(colg, op_name)(a, *args, **kwargs)
350350
results.append(new_res)
351-
352351
# make sure we find a good name
353352
name = com.get_callable_name(a) or a
354353
keys.append(name)
@@ -691,10 +690,9 @@ def agg_axis(self) -> Index:
691690
return self.obj._get_agg_axis(self.axis)
692691

693692
def agg_or_apply_list_like(
694-
self, op_name: Literal["agg", "apply"]
693+
self, op_name: Literal["agg", "apply"], numeric_only=False, **kwargs
695694
) -> DataFrame | Series:
696695
obj = self.obj
697-
kwargs = self.kwargs
698696

699697
if op_name == "apply":
700698
if isinstance(self, FrameApply):
@@ -708,7 +706,6 @@ def agg_or_apply_list_like(
708706

709707
if getattr(obj, "axis", 0) == 1:
710708
raise NotImplementedError("axis other than 0 is not supported")
711-
712709
keys, results = self.compute_list_like(op_name, obj, kwargs)
713710
result = self.wrap_results_list_like(keys, results)
714711
return result
@@ -1485,28 +1482,24 @@ def transform(self):
14851482
raise NotImplementedError
14861483

14871484
def agg_or_apply_list_like(
1488-
self, op_name: Literal["agg", "apply"]
1485+
self, op_name: Literal["agg", "apply"], numeric_only=False, **kwargs
14891486
) -> DataFrame | Series:
14901487
obj = self.obj
1491-
kwargs = self.kwargs
1488+
14921489
if op_name == "apply":
14931490
kwargs = {**kwargs, "by_row": False}
14941491

14951492
if getattr(obj, "axis", 0) == 1:
14961493
raise NotImplementedError("axis other than 0 is not supported")
14971494

1498-
if obj._selected_obj.ndim == 1:
1499-
# For SeriesGroupBy this matches _obj_with_exclusions
1500-
selected_obj = obj._selected_obj
1501-
else:
1502-
selected_obj = obj._obj_with_exclusions
1503-
1495+
mgr = obj._get_data_to_aggregate(numeric_only=numeric_only)
1496+
selected_obj = obj._wrap_agged_manager(mgr)
15041497
# Only set as_index=True on groupby objects, not Window or Resample
15051498
# that inherit from this class.
15061499
with com.temp_setattr(
15071500
obj, "as_index", True, condition=hasattr(obj, "as_index")
15081501
):
1509-
keys, results = self.compute_list_like(op_name, selected_obj, kwargs)
1502+
keys, results = self.compute_list_like(op_name, selected_obj, **kwargs)
15101503
result = self.wrap_results_list_like(keys, results)
15111504
return result
15121505

pandas/core/groupby/generic.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -1553,7 +1553,7 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
15531553

15541554
else:
15551555
# try to treat as if we are passing a list
1556-
gba = GroupByApply(self, [func], args=(), kwargs={})
1556+
gba = GroupByApply(self, [func], args=args, kwargs=kwargs)
15571557
try:
15581558
result = gba.agg()
15591559

@@ -1582,15 +1582,15 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
15821582

15831583
agg = aggregate
15841584

1585-
def _python_agg_general(self, func, *args, **kwargs):
1585+
def _python_agg_general(self, func, *args, numeric_only=False, **kwargs):
15861586
f = lambda x: func(x, *args, **kwargs)
15871587

15881588
if self.ngroups == 0:
15891589
# e.g. test_evaluate_with_empty_groups different path gets different
15901590
# result dtype in empty case.
15911591
return self._python_apply_general(f, self._selected_obj, is_agg=True)
1592-
1593-
obj = self._obj_with_exclusions
1592+
mgr = self._get_data_to_aggregate(numeric_only=numeric_only)
1593+
obj = self._wrap_agged_manager(mgr)
15941594

15951595
if not len(obj.columns):
15961596
# e.g. test_margins_no_values_no_cols
@@ -1605,19 +1605,19 @@ def _python_agg_general(self, func, *args, **kwargs):
16051605
res.columns = obj.columns.copy(deep=False)
16061606
return self._wrap_aggregated_output(res)
16071607

1608-
def _aggregate_frame(self, func, *args, **kwargs) -> DataFrame:
1608+
def _aggregate_frame(self, func, *args, numeric_only=False, **kwargs) -> DataFrame:
16091609
if self._grouper.nkeys != 1:
16101610
raise AssertionError("Number of keys must be 1")
16111611

1612-
obj = self._obj_with_exclusions
1613-
1612+
mgr = self._get_data_to_aggregate(numeric_only=numeric_only)
1613+
data = self._wrap_agged_manager(mgr)
16141614
result: dict[Hashable, NDFrame | np.ndarray] = {}
1615-
for name, grp_df in self._grouper.get_iterator(obj):
1615+
for name, grp_df in self._grouper.get_iterator(data):
16161616
fres = func(grp_df, *args, **kwargs)
16171617
result[name] = fres
16181618

16191619
result_index = self._grouper.result_index
1620-
out = self.obj._constructor(result, index=obj.columns, columns=result_index)
1620+
out = self.obj._constructor(result, index=data.columns, columns=result_index)
16211621
out = out.T
16221622

16231623
return out

pandas/tests/groupby/aggregate/test_aggregate.py

+50
Original file line numberDiff line numberDiff line change
@@ -1663,3 +1663,53 @@ def func(x):
16631663
msg = "length must not be 0"
16641664
with pytest.raises(ValueError, match=msg):
16651665
df.groupby("A", observed=False).agg(func)
1666+
1667+
1668+
@pytest.mark.parametrize(
1669+
"aggfunc",
1670+
[
1671+
"mean",
1672+
np.mean,
1673+
["sum", "mean"],
1674+
[np.sum, np.mean],
1675+
["sum", np.mean],
1676+
lambda x: x.mean(),
1677+
{"A": "mean"},
1678+
{"A": "mean", "B": "sum"},
1679+
{"A": np.mean},
1680+
],
1681+
ids=[
1682+
" string_mean ",
1683+
" numpy_mean ",
1684+
" list_of_str_and_str ",
1685+
" list_of_numpy_and_numpy ",
1686+
" list_of_str_and_numpy ",
1687+
" lambda ",
1688+
" dict_with_str ",
1689+
" dict with 2 vars ",
1690+
" dict with numpy",
1691+
],
1692+
)
1693+
@pytest.mark.parametrize(
1694+
"groupers",
1695+
["groupby1", "groupby2", ["groupby1", "groupby2"]],
1696+
ids=[" 1_grouper_str ", " 1_grouper_int ", " 2_groupers_str_and_int "],
1697+
)
1698+
@pytest.mark.parametrize(
1699+
"numeric_only", [True, None], ids=[" numeric_only True ", " no_numeric_only_arg "]
1700+
) # need to add other kwargs
1701+
def test_different_combinations_of_groupby_agg(aggfunc, groupers, numeric_only):
1702+
df = DataFrame(
1703+
{
1704+
"A": [1, 2, 3, 4, 5],
1705+
"B": [10, 20, 30, 40, 50],
1706+
"groupby1": ["diamond", "diamond", "spade", "spade", "spade"],
1707+
"groupby2": [1, 1, 1, 2, 2],
1708+
"attr": ["a", "b", "c", "d", "e"],
1709+
}
1710+
)
1711+
if numeric_only or isinstance(aggfunc, dict):
1712+
df.groupby(by=groupers).agg(func=aggfunc, numeric_only=numeric_only)
1713+
else:
1714+
with pytest.raises(TypeError):
1715+
df.groupby(by=groupers).agg(func=aggfunc)

test.tar

10 KB
Binary file not shown.

0 commit comments

Comments
 (0)