Skip to content

Commit 7171b76

Browse files
jbrockmendelJulianWgs
authored andcommitted
REF: De-duplicate agg_series (pandas-dev#41210)
1 parent 61608b9 commit 7171b76

File tree

3 files changed

+16
-18
lines changed

3 files changed

+16
-18
lines changed

pandas/core/groupby/generic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1686,7 +1686,7 @@ def _insert_inaxis_grouper_inplace(self, result: DataFrame) -> None:
16861686

16871687
def _wrap_aggregated_output(
16881688
self,
1689-
output: Mapping[base.OutputKey, Series | np.ndarray],
1689+
output: Mapping[base.OutputKey, Series | ArrayLike],
16901690
) -> DataFrame:
16911691
"""
16921692
Wraps the output of DataFrameGroupBy aggregations into the expected result.

pandas/core/groupby/groupby.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1061,7 +1061,7 @@ def _set_result_index_ordered(
10611061

10621062
return result
10631063

1064-
def _wrap_aggregated_output(self, output: Mapping[base.OutputKey, np.ndarray]):
1064+
def _wrap_aggregated_output(self, output: Mapping[base.OutputKey, ArrayLike]):
10651065
raise AbstractMethodError(self)
10661066

10671067
def _wrap_transformed_output(self, output: Mapping[base.OutputKey, ArrayLike]):
@@ -1227,7 +1227,7 @@ def _python_agg_general(self, func, *args, **kwargs):
12271227
f = lambda x: func(x, *args, **kwargs)
12281228

12291229
# iterate through "columns" ex exclusions to populate output dict
1230-
output: dict[base.OutputKey, np.ndarray] = {}
1230+
output: dict[base.OutputKey, ArrayLike] = {}
12311231

12321232
for idx, obj in enumerate(self._iterate_slices()):
12331233
name = obj.name

pandas/core/groupby/ops.py

+13-15
Original file line numberDiff line numberDiff line change
@@ -923,7 +923,8 @@ def _cython_operation(
923923
**kwargs,
924924
)
925925

926-
def agg_series(self, obj: Series, func: F):
926+
@final
927+
def agg_series(self, obj: Series, func: F) -> tuple[ArrayLike, np.ndarray]:
927928
# Caller is responsible for checking ngroups != 0
928929
assert self.ngroups != 0
929930

@@ -952,8 +953,9 @@ def agg_series(self, obj: Series, func: F):
952953
raise
953954
return self._aggregate_series_pure_python(obj, func)
954955

955-
@final
956-
def _aggregate_series_fast(self, obj: Series, func: F):
956+
def _aggregate_series_fast(
957+
self, obj: Series, func: F
958+
) -> tuple[ArrayLike, np.ndarray]:
957959
# At this point we have already checked that
958960
# - obj.index is not a MultiIndex
959961
# - obj is backed by an ndarray, not ExtensionArray
@@ -1157,18 +1159,14 @@ def groupings(self) -> list[grouper.Grouping]:
11571159
for lvl, name in zip(self.levels, self.names)
11581160
]
11591161

1160-
def agg_series(self, obj: Series, func: F):
1161-
# Caller is responsible for checking ngroups != 0
1162-
assert self.ngroups != 0
1163-
assert len(self.bins) > 0 # otherwise we'd get IndexError in get_result
1164-
1165-
if is_extension_array_dtype(obj.dtype):
1166-
# preempt SeriesBinGrouper from raising TypeError
1167-
return self._aggregate_series_pure_python(obj, func)
1168-
1169-
elif obj.index._has_complex_internals:
1170-
return self._aggregate_series_pure_python(obj, func)
1171-
1162+
def _aggregate_series_fast(
1163+
self, obj: Series, func: F
1164+
) -> tuple[ArrayLike, np.ndarray]:
1165+
# At this point we have already checked that
1166+
# - obj.index is not a MultiIndex
1167+
# - obj is backed by an ndarray, not ExtensionArray
1168+
# - ngroups != 0
1169+
# - len(self.bins) > 0
11721170
grouper = libreduction.SeriesBinGrouper(obj, func, self.bins)
11731171
return grouper.get_result()
11741172

0 commit comments

Comments
 (0)