From 15603c1bd1e97a1998661ecb545f2e09df9b0d66 Mon Sep 17 00:00:00 2001
From: Brock
Date: Fri, 18 Sep 2020 11:00:53 -0700
Subject: [PATCH 1/2] REF: match cython and non-cython reduction behavior

---
 pandas/_libs/reduction.pyx     | 17 ++++++++++++-----
 pandas/core/groupby/generic.py | 10 +++++++---
 pandas/core/groupby/ops.py     | 12 ++++--------
 3 files changed, 23 insertions(+), 16 deletions(-)

diff --git a/pandas/_libs/reduction.pyx b/pandas/_libs/reduction.pyx
index 8161b5c5c2b11..3a0fda5aed620 100644
--- a/pandas/_libs/reduction.pyx
+++ b/pandas/_libs/reduction.pyx
@@ -16,12 +16,12 @@ from pandas._libs cimport util
 from pandas._libs.lib import is_scalar, maybe_convert_objects
 
 
-cdef _check_result_array(object obj, Py_ssize_t cnt):
+cpdef check_result_array(object obj, Py_ssize_t cnt):
 
     if (util.is_array(obj) or
             (isinstance(obj, list) and len(obj) == cnt) or
             getattr(obj, 'shape', None) == (cnt,)):
-        raise ValueError('Function does not reduce')
+        raise ValueError('Must produce aggregated value')
 
 
 cdef class _BaseGrouper:
@@ -74,12 +74,14 @@ cdef class _BaseGrouper:
         cached_ityp._engine.clear_mapping()
         cached_ityp._cache.clear()  # e.g. inferred_freq must go
         res = self.f(cached_typ)
-        res = _extract_result(res)
+        res = extract_result(res)
         if not initialized:
             # On the first pass, we check the output shape to see
             #  if this looks like a reduction.
             initialized = True
-            _check_result_array(res, len(self.dummy_arr))
+            # In all tests other than test_series_grouper and
+            #  test_series_bin_grouper, we have len(self.dummy_arr) == 0
+            check_result_array(res, len(self.dummy_arr))
 
         return res, initialized
 
@@ -278,9 +280,14 @@ cdef class SeriesGrouper(_BaseGrouper):
         return result, counts
 
 
-cdef inline _extract_result(object res, bint squeeze=True):
+cpdef inline extract_result(object res, bint squeeze=True):
     """ extract the result object, it might be a 0-dim ndarray
         or a len-1 0-dim, or a scalar """
+    if hasattr(res, "_values"):
+        # Preserve EA
+        res = res._values
+        if squeeze and res.ndim == 1 and len(res) == 1:
+            res = res[0]
     if hasattr(res, 'values') and util.is_array(res.values):
         res = res.values
     if util.is_array(res):
diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py
index a931221ef3ce1..9c3e6da3ec7a1 100644
--- a/pandas/core/groupby/generic.py
+++ b/pandas/core/groupby/generic.py
@@ -29,7 +29,7 @@
 
 import numpy as np
 
-from pandas._libs import lib
+from pandas._libs import lib, reduction as libreduction
 from pandas._typing import ArrayLike, FrameOrSeries, FrameOrSeriesUnion
 from pandas.util._decorators import Appender, Substitution, doc
 
@@ -471,12 +471,16 @@ def _get_index() -> Index:
 
     def _aggregate_named(self, func, *args, **kwargs):
         result = {}
+        initialized = False
 
         for name, group in self:
             group.name = name
             output = func(group, *args, **kwargs)
-            if isinstance(output, (Series, Index, np.ndarray)):
-                raise ValueError("Must produce aggregated value")
+            output = libreduction.extract_result(output)
+            if not initialized:
+                # We only do this validation on the first iteration
+                libreduction.check_result_array(output, 0)
+                initialized = True
             result[name] = output
 
         return result
diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py
index e9525f03368fa..832d969193e49 100644
--- a/pandas/core/groupby/ops.py
+++ b/pandas/core/groupby/ops.py
@@ -623,7 +623,7 @@ def agg_series(self, obj: Series, func: F):
         try:
             return self._aggregate_series_fast(obj, func)
         except ValueError as err:
-            if "Function does not reduce" in str(err):
+            if "Must produce aggregated value" in str(err):
                 # raised in libreduction
                 pass
             else:
@@ -659,15 +659,11 @@ def _aggregate_series_pure_python(self, obj: Series, func: F):
         for label, group in splitter:
             res = func(group)
+            res = libreduction.extract_result(res)
 
             if result is None:
-                if isinstance(res, (Series, Index, np.ndarray)):
-                    if len(res) == 1:
-                        # e.g. test_agg_lambda_with_timezone lambda e: e.head(1)
-                        # FIXME: are we potentially losing important res.index info?
-                        res = res.item()
-                    else:
-                        raise ValueError("Function does not reduce")
+                # We only do this validation on the first iteration
+                libreduction.check_result_array(res, 0)
 
                 result = np.empty(ngroups, dtype="O")
 
             counts[label] = group.shape[0]

From 2702f899e7fb03348cb58d601188b26a802ca4d9 Mon Sep 17 00:00:00 2001
From: Brock
Date: Fri, 18 Sep 2020 13:55:22 -0700
Subject: [PATCH 2/2] REF: match libreduction pattern

---
 pandas/core/groupby/generic.py |  5 ++++-
 pandas/core/groupby/ops.py     | 11 +++++++----
 2 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py
index 9c3e6da3ec7a1..5075e08c78060 100644
--- a/pandas/core/groupby/generic.py
+++ b/pandas/core/groupby/generic.py
@@ -474,7 +474,10 @@ def _aggregate_named(self, func, *args, **kwargs):
         initialized = False
 
         for name, group in self:
-            group.name = name
+            # Each step of this loop corresponds to
+            #  libreduction._BaseGrouper._apply_to_group
+            group.name = name  # NB: libreduction does not pin name
+
             output = func(group, *args, **kwargs)
             output = libreduction.extract_result(output)
             if not initialized:
diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py
index 832d969193e49..b3f91d4623c84 100644
--- a/pandas/core/groupby/ops.py
+++ b/pandas/core/groupby/ops.py
@@ -653,23 +653,26 @@ def _aggregate_series_pure_python(self, obj: Series, func: F):
         group_index, _, ngroups = self.group_info
 
         counts = np.zeros(ngroups, dtype=int)
-        result = None
+        result = np.empty(ngroups, dtype="O")
+        initialized = False
 
         splitter = get_splitter(obj, group_index, ngroups, axis=0)
 
         for label, group in splitter:
+
+            # Each step of this loop corresponds to
+            #  libreduction._BaseGrouper._apply_to_group
             res = func(group)
             res = libreduction.extract_result(res)
 
-            if result is None:
+            if not initialized:
                 # We only do this validation on the first iteration
                 libreduction.check_result_array(res, 0)
-
-                result = np.empty(ngroups, dtype="O")
+                initialized = True
 
             counts[label] = group.shape[0]
             result[label] = res
 
-        assert result is not None
         result = lib.maybe_convert_objects(result, try_float=0)
 
         # TODO: maybe_cast_to_extension_array?
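
Reviewer note (not part of the patch series): both commits converge the cython path (libreduction._BaseGrouper) and the pure-Python fallback on the same pattern -- unwrap each group's result with extract_result, then validate only the first result with check_result_array, raising "Must produce aggregated value" when the UDF returns an array-like instead of a scalar. Below is a minimal standalone sketch of that pattern; the helper bodies are simplified approximations of the pandas helpers, not the actual implementations, and aggregate_groups is a hypothetical stand-in for the loops in _aggregate_named / _aggregate_series_pure_python.

import numpy as np


def check_result_array(obj, cnt):
    # Simplified analogue of libreduction.check_result_array: reject array-like
    # output (an ndarray, a list of length `cnt`, or anything whose shape is
    # (cnt,)), i.e. a UDF that did not reduce its group to a scalar.
    if (
        isinstance(obj, np.ndarray)
        or (isinstance(obj, list) and len(obj) == cnt)
        or getattr(obj, "shape", None) == (cnt,)
    ):
        raise ValueError("Must produce aggregated value")


def extract_result(res, squeeze=True):
    # Simplified analogue of libreduction.extract_result: go through ._values
    # so an ExtensionArray dtype is preserved, and unwrap a length-1 result to
    # its scalar when squeeze is True.
    if hasattr(res, "_values"):
        res = res._values
        if squeeze and res.ndim == 1 and len(res) == 1:
            res = res[0]
    return res


def aggregate_groups(groups, func):
    # Hypothetical stand-in for the loops in _aggregate_named and
    # _aggregate_series_pure_python: unwrap every result, but validate the
    # output shape only on the first iteration.
    result = {}
    initialized = False
    for name, group in groups:
        res = extract_result(func(group))
        if not initialized:
            check_result_array(res, 0)
            initialized = True
        result[name] = res
    return result

With this pattern, a reducing UDF such as lambda s: s.sum() passes through unchanged in both engines, while a non-reducing one such as lambda s: s.values raises "Must produce aggregated value" on the first group instead of failing only in the cython fast path.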