Skip to content

Commit 652177f

Browse files
jbrockmendelyeshsurya
authored andcommitted
REF: simpilify _cython_agg_general (pandas-dev#41271)
1 parent 1a5ac31 commit 652177f

File tree

2 files changed

+34
-42
lines changed

2 files changed

+34
-42
lines changed

pandas/core/groupby/generic.py

+34-33
Original file line numberDiff line numberDiff line change
@@ -345,47 +345,48 @@ def _aggregate_multiple_funcs(self, arg):
345345
def _cython_agg_general(
346346
self, how: str, alt=None, numeric_only: bool = True, min_count: int = -1
347347
):
348-
output: dict[base.OutputKey, ArrayLike] = {}
349-
# Ideally we would be able to enumerate self._iterate_slices and use
350-
# the index from enumeration as the key of output, but ohlc in particular
351-
# returns a (n x 4) array. Output requires 1D ndarrays as values, so we
352-
# need to slice that up into 1D arrays
353-
idx = 0
354-
for obj in self._iterate_slices():
355-
name = obj.name
356-
is_numeric = is_numeric_dtype(obj.dtype)
357-
if numeric_only and not is_numeric:
358-
continue
359-
360-
objvals = obj._values
361348

362-
if isinstance(objvals, Categorical):
363-
if self.grouper.ngroups > 0:
364-
# without special-casing, we would raise, then in fallback
365-
# would eventually call agg_series but without re-casting
366-
# to Categorical
367-
# equiv: res_values, _ = self.grouper.agg_series(obj, alt)
368-
res_values, _ = self.grouper._aggregate_series_pure_python(obj, alt)
369-
else:
370-
# equiv: res_values = self._python_agg_general(alt)
371-
res_values = self._python_apply_general(alt, self._selected_obj)
349+
obj = self._selected_obj
350+
objvals = obj._values
372351

373-
result = type(objvals)._from_sequence(res_values, dtype=objvals.dtype)
352+
if numeric_only and not is_numeric_dtype(obj.dtype):
353+
raise DataError("No numeric types to aggregate")
374354

375-
else:
355+
# This is overkill because it is only called once, but is here to
356+
# mirror the array_func used in DataFrameGroupBy._cython_agg_general
357+
def array_func(values: ArrayLike) -> ArrayLike:
358+
try:
376359
result = self.grouper._cython_operation(
377-
"aggregate", obj._values, how, axis=0, min_count=min_count
360+
"aggregate", values, how, axis=0, min_count=min_count
378361
)
362+
except NotImplementedError:
363+
ser = Series(values) # equiv 'obj' from outer frame
364+
if self.grouper.ngroups > 0:
365+
res_values, _ = self.grouper.agg_series(ser, alt)
366+
else:
367+
# equiv: res_values = self._python_agg_general(alt)
368+
# error: Incompatible types in assignment (expression has
369+
# type "Union[DataFrame, Series]", variable has type
370+
# "Union[ExtensionArray, ndarray]")
371+
res_values = self._python_apply_general( # type: ignore[assignment]
372+
alt, ser
373+
)
379374

380-
assert result.ndim == 1
381-
key = base.OutputKey(label=name, position=idx)
382-
output[key] = result
383-
idx += 1
375+
if isinstance(values, Categorical):
376+
# Because we only get here with known dtype-preserving
377+
# reductions, we cast back to Categorical.
378+
# TODO: if we ever get "rank" working, exclude it here.
379+
result = type(values)._from_sequence(res_values, dtype=values.dtype)
380+
else:
381+
result = res_values
382+
return result
384383

385-
if not output:
386-
raise DataError("No numeric types to aggregate")
384+
result = array_func(objvals)
387385

388-
return self._wrap_aggregated_output(output)
386+
ser = self.obj._constructor(
387+
result, index=self.grouper.result_index, name=obj.name
388+
)
389+
return self._reindex_output(ser)
389390

390391
def _wrap_aggregated_output(
391392
self,

pandas/core/groupby/groupby.py

-9
Original file line numberDiff line numberDiff line change
@@ -1281,15 +1281,6 @@ def _agg_general(
12811281
)
12821282
except DataError:
12831283
pass
1284-
except NotImplementedError as err:
1285-
if "function is not implemented for this dtype" in str(
1286-
err
1287-
) or "category dtype not supported" in str(err):
1288-
# raised in _get_cython_function, in some cases can
1289-
# be trimmed by implementing cython funcs for more dtypes
1290-
pass
1291-
else:
1292-
raise
12931284

12941285
# apply a non-cython aggregation
12951286
if result is None:

0 commit comments

Comments
 (0)