Skip to content

Commit 06eb8db

Browse files
authored
CLN: Some groupby internals (pandas-dev#31915)
* CLN: Some groupby internals
* Additional annotation
1 parent 05ab8ba commit 06eb8db

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

pandas/core/groupby/ops.py

+6-6
Original file line number | Diff line number | Diff line change
@@ -169,7 +169,7 @@ def apply(self, f, data: FrameOrSeries, axis: int = 0):
169169
and not sdata.index._has_complex_internals
170170
):
171171
try:
172-
result_values, mutated = splitter.fast_apply(f, group_keys)
172+
result_values, mutated = splitter.fast_apply(f, sdata, group_keys)
173173

174174
except libreduction.InvalidApply as err:
175175
# This Exception is raised if `f` triggers an exception
@@ -925,11 +925,9 @@ def _chop(self, sdata: Series, slice_obj: slice) -> Series:
925925

926926

927927
class FrameSplitter(DataSplitter):
928-
def fast_apply(self, f, names):
928+
def fast_apply(self, f, sdata: FrameOrSeries, names):
929929
# must return keys::list, values::list, mutated::bool
930930
starts, ends = lib.generate_slices(self.slabels, self.ngroups)
931-
932-
sdata = self._get_sorted_data()
933931
return libreduction.apply_frame_axis0(sdata, f, names, starts, ends)
934932

935933
def _chop(self, sdata: DataFrame, slice_obj: slice) -> DataFrame:
@@ -939,11 +937,13 @@ def _chop(self, sdata: DataFrame, slice_obj: slice) -> DataFrame:
939937
return sdata.iloc[:, slice_obj]
940938

941939

942-
def get_splitter(data: FrameOrSeries, *args, **kwargs) -> DataSplitter:
940+
def get_splitter(
941+
data: FrameOrSeries, labels: np.ndarray, ngroups: int, axis: int = 0
942+
) -> DataSplitter:
943943
if isinstance(data, Series):
944944
klass: Type[DataSplitter] = SeriesSplitter
945945
else:
946946
# i.e. DataFrame
947947
klass = FrameSplitter
948948

949-
return klass(data, *args, **kwargs)
949+
return klass(data, labels, ngroups, axis)

pandas/tests/groupby/test_apply.py

+2-1
Original file line number | Diff line number | Diff line change
@@ -108,8 +108,9 @@ def f(g):
108108

109109
splitter = grouper._get_splitter(g._selected_obj, axis=g.axis)
110110
group_keys = grouper._get_group_keys()
111+
sdata = splitter._get_sorted_data()
111112

112-
values, mutated = splitter.fast_apply(f, group_keys)
113+
values, mutated = splitter.fast_apply(f, sdata, group_keys)
113114

114115
assert not mutated
115116

0 commit comments

Comments
 (0)