Skip to content

Commit 80f0a74

Browse files
authored
Align cython and python reduction code paths (#36459)
1 parent ff11c05 commit 80f0a74

File tree

3 files changed

+34 -21 lines changed

pandas/_libs/reduction.pyx

+12 -5
Original file line number | Diff line number | Diff line change
@@ -16,12 +16,12 @@ from pandas._libs cimport util
1616
from pandas._libs.lib import is_scalar, maybe_convert_objects
1717

1818

19-
cdef _check_result_array(object obj, Py_ssize_t cnt):
19+
cpdef check_result_array(object obj, Py_ssize_t cnt):
2020

2121
if (util.is_array(obj) or
2222
(isinstance(obj, list) and len(obj) == cnt) or
2323
getattr(obj, 'shape', None) == (cnt,)):
24-
raise ValueError('Function does not reduce')
24+
raise ValueError('Must produce aggregated value')
2525

2626

2727
cdef class _BaseGrouper:
@@ -74,12 +74,14 @@ cdef class _BaseGrouper:
7474
cached_ityp._engine.clear_mapping()
7575
cached_ityp._cache.clear() # e.g. inferred_freq must go
7676
res = self.f(cached_typ)
77-
res = _extract_result(res)
77+
res = extract_result(res)
7878
if not initialized:
7979
# On the first pass, we check the output shape to see
8080
# if this looks like a reduction.
8181
initialized = True
82-
_check_result_array(res, len(self.dummy_arr))
82+
# In all tests other than test_series_grouper and
83+
# test_series_bin_grouper, we have len(self.dummy_arr) == 0
84+
check_result_array(res, len(self.dummy_arr))
8385

8486
return res, initialized
8587

@@ -278,9 +280,14 @@ cdef class SeriesGrouper(_BaseGrouper):
278280
return result, counts
279281

280282

281-
cdef inline _extract_result(object res, bint squeeze=True):
283+
cpdef inline extract_result(object res, bint squeeze=True):
282284
""" extract the result object, it might be a 0-dim ndarray
283285
or a len-1 0-dim, or a scalar """
286+
if hasattr(res, "_values"):
287+
# Preserve EA
288+
res = res._values
289+
if squeeze and res.ndim == 1 and len(res) == 1:
290+
res = res[0]
284291
if hasattr(res, 'values') and util.is_array(res.values):
285292
res = res.values
286293
if util.is_array(res):

pandas/core/groupby/generic.py

+11 -4
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,7 @@
2929

3030
import numpy as np
3131

32-
from pandas._libs import lib
32+
from pandas._libs import lib, reduction as libreduction
3333
from pandas._typing import ArrayLike, FrameOrSeries, FrameOrSeriesUnion
3434
from pandas.util._decorators import Appender, Substitution, doc
3535

@@ -471,12 +471,19 @@ def _get_index() -> Index:
471471

472472
def _aggregate_named(self, func, *args, **kwargs):
473473
result = {}
474+
initialized = False
474475

475476
for name, group in self:
476-
group.name = name
477+
# Each step of this loop corresponds to
478+
# libreduction._BaseGrouper._apply_to_group
479+
group.name = name # NB: libreduction does not pin name
480+
477481
output = func(group, *args, **kwargs)
478-
if isinstance(output, (Series, Index, np.ndarray)):
479-
raise ValueError("Must produce aggregated value")
482+
output = libreduction.extract_result(output)
483+
if not initialized:
484+
# We only do this validation on the first iteration
485+
libreduction.check_result_array(output, 0)
486+
initialized = True
480487
result[name] = output
481488

482489
return result

pandas/core/groupby/ops.py

+11 -12
Original file line number | Diff line number | Diff line change
@@ -623,7 +623,7 @@ def agg_series(self, obj: Series, func: F):
623623
try:
624624
return self._aggregate_series_fast(obj, func)
625625
except ValueError as err:
626-
if "Function does not reduce" in str(err):
626+
if "Must produce aggregated value" in str(err):
627627
# raised in libreduction
628628
pass
629629
else:
@@ -653,27 +653,26 @@ def _aggregate_series_pure_python(self, obj: Series, func: F):
653653
group_index, _, ngroups = self.group_info
654654

655655
counts = np.zeros(ngroups, dtype=int)
656-
result = None
656+
result = np.empty(ngroups, dtype="O")
657+
initialized = False
657658

658659
splitter = get_splitter(obj, group_index, ngroups, axis=0)
659660

660661
for label, group in splitter:
662+
663+
# Each step of this loop corresponds to
664+
# libreduction._BaseGrouper._apply_to_group
661665
res = func(group)
666+
res = libreduction.extract_result(res)
662667

663-
if result is None:
664-
if isinstance(res, (Series, Index, np.ndarray)):
665-
if len(res) == 1:
666-
# e.g. test_agg_lambda_with_timezone lambda e: e.head(1)
667-
# FIXME: are we potentially losing important res.index info?
668-
res = res.item()
669-
else:
670-
raise ValueError("Function does not reduce")
671-
result = np.empty(ngroups, dtype="O")
668+
if not initialized:
669+
# We only do this validation on the first iteration
670+
libreduction.check_result_array(res, 0)
671+
initialized = True
672672

673673
counts[label] = group.shape[0]
674674
result[label] = res
675675

676-
assert result is not None
677676
result = lib.maybe_convert_objects(result, try_float=0)
678677
# TODO: maybe_cast_to_extension_array?
679678

0 commit comments

Comments
 (0)