From 0d4302858c61acc3ad35c8d673238e3d4fe2798b Mon Sep 17 00:00:00 2001 From: Brock Date: Thu, 16 Feb 2023 20:16:46 -0800 Subject: [PATCH] REF: simplify python_agg_general --- pandas/core/groupby/generic.py | 68 ++++++++++++++++++++++++++-------- pandas/core/groupby/groupby.py | 28 -------------- 2 files changed, 53 insertions(+), 43 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 3a9ffc8631441..19fba398feb08 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -250,14 +250,28 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs) if cyfunc and not args and not kwargs: return getattr(self, cyfunc)() + if self.ngroups == 0: + # e.g. test_evaluate_with_empty_groups without any groups to + # iterate over, we have no output on which to do dtype + # inference. We default to using the existing dtype. + # xref GH#51445 + obj = self._obj_with_exclusions + return self.obj._constructor( + [], + name=self.obj.name, + index=self.grouper.result_index, + dtype=obj.dtype, + ) + if self.grouper.nkeys > 1: return self._python_agg_general(func, *args, **kwargs) try: return self._python_agg_general(func, *args, **kwargs) except KeyError: - # TODO: KeyError is raised in _python_agg_general, - # see test_groupby.test_basic + # KeyError raised in test_groupby.test_basic is bc the func does + # a dictionary lookup on group.name, but group name is not + # pinned in _python_agg_general, only in _aggregate_named result = self._aggregate_named(func, *args, **kwargs) # result is a dict whose keys are the elements of result_index @@ -267,6 +281,15 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs) agg = aggregate + def _python_agg_general(self, func, *args, **kwargs): + func = com.is_builtin_func(func) + f = lambda x: func(x, *args, **kwargs) + + obj = self._obj_with_exclusions + result = self.grouper.agg_series(obj, f) + res = obj._constructor(result, name=obj.name) + return self._wrap_aggregated_output(res) + def _aggregate_multiple_funcs(self, arg, *args, **kwargs) -> DataFrame: if isinstance(arg, dict): if self.as_index: @@ -308,18 +331,6 @@ def _aggregate_multiple_funcs(self, arg, *args, **kwargs) -> DataFrame: output = self._reindex_output(output) return output - def _indexed_output_to_ndframe( - self, output: Mapping[base.OutputKey, ArrayLike] - ) -> Series: - """ - Wrap the dict result of a GroupBy aggregation into a Series. - """ - assert len(output) == 1 - values = next(iter(output.values())) - result = self.obj._constructor(values) - result.name = self.obj.name - return result - def _wrap_applied_output( self, data: Series, @@ -1319,6 +1330,31 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs) agg = aggregate + def _python_agg_general(self, func, *args, **kwargs): + func = com.is_builtin_func(func) + f = lambda x: func(x, *args, **kwargs) + + # iterate through "columns" ex exclusions to populate output dict + output: dict[base.OutputKey, ArrayLike] = {} + + if self.ngroups == 0: + # e.g. test_evaluate_with_empty_groups different path gets different + # result dtype in empty case. + return self._python_apply_general(f, self._selected_obj, is_agg=True) + + for idx, obj in enumerate(self._iterate_slices()): + name = obj.name + result = self.grouper.agg_series(obj, f) + key = base.OutputKey(label=name, position=idx) + output[key] = result + + if not output: + # e.g. test_margins_no_values_no_cols + return self._python_apply_general(f, self._selected_obj) + + res = self._indexed_output_to_ndframe(output) + return self._wrap_aggregated_output(res) + def _iterate_slices(self) -> Iterable[Series]: obj = self._selected_obj if self.axis == 1: @@ -1885,7 +1921,9 @@ def nunique(self, dropna: bool = True) -> DataFrame: if self.axis != 0: # see test_groupby_crash_on_nunique - return self._python_agg_general(lambda sgb: sgb.nunique(dropna)) + return self._python_apply_general( + lambda sgb: sgb.nunique(dropna), self._obj_with_exclusions, is_agg=True + ) obj = self._obj_with_exclusions results = self._apply_to_column_groupbys( diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 0f0c1daf2127b..dee68c01587b1 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -1409,34 +1409,6 @@ def _python_apply_general( is_transform, ) - # TODO: I (jbrockmendel) think this should be equivalent to doing grouped_reduce - # on _agg_py_fallback, but trying that here fails a bunch of tests 2023-02-07. - @final - def _python_agg_general(self, func, *args, **kwargs): - func = com.is_builtin_func(func) - f = lambda x: func(x, *args, **kwargs) - - # iterate through "columns" ex exclusions to populate output dict - output: dict[base.OutputKey, ArrayLike] = {} - - if self.ngroups == 0: - # e.g. test_evaluate_with_empty_groups different path gets different - # result dtype in empty case. - return self._python_apply_general(f, self._selected_obj, is_agg=True) - - for idx, obj in enumerate(self._iterate_slices()): - name = obj.name - result = self.grouper.agg_series(obj, f) - key = base.OutputKey(label=name, position=idx) - output[key] = result - - if not output: - # e.g. test_groupby_crash_on_nunique, test_margins_no_values_no_cols - return self._python_apply_general(f, self._selected_obj) - - res = self._indexed_output_to_ndframe(output) - return self._wrap_aggregated_output(res) - @final def _agg_general( self,