Skip to content

Commit 23e2841

Browse files
authored
PERF: cache sortings in groupby.ops (#51792)
* PERF: cache sort_idx, sorted_ids * move apply_groupwise down * re-use cached sorted_idx/sorted_ids
1 parent 7e88054 commit 23e2841

File tree

2 files changed

+81
-54
lines changed

2 files changed

+81
-54
lines changed

pandas/core/groupby/groupby.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1167,8 +1167,8 @@ def _wrap_applied_output(
11671167
@final
11681168
def _numba_prep(self, data: DataFrame):
11691169
ids, _, ngroups = self.grouper.group_info
1170-
sorted_index = get_group_index_sorter(ids, ngroups)
1171-
sorted_ids = algorithms.take_nd(ids, sorted_index, allow_fill=False)
1170+
sorted_index = self.grouper._sort_idx
1171+
sorted_ids = self.grouper._sorted_ids
11721172

11731173
sorted_data = data.take(sorted_index, axis=self.axis).to_numpy()
11741174
if len(self.grouper.groupings) > 1:

pandas/core/groupby/ops.py

+79-52
Original file line numberDiff line numberDiff line change
@@ -734,7 +734,14 @@ def _get_splitter(self, data: NDFrame, axis: AxisInt = 0) -> DataSplitter:
734734
Generator yielding subsetted objects
735735
"""
736736
ids, _, ngroups = self.group_info
737-
return _get_splitter(data, ids, ngroups, axis=axis)
737+
return _get_splitter(
738+
data,
739+
ids,
740+
ngroups,
741+
sorted_ids=self._sorted_ids,
742+
sort_idx=self._sort_idx,
743+
axis=axis,
744+
)
738745

739746
@final
740747
@cache_readonly
@@ -747,45 +754,6 @@ def group_keys_seq(self):
747754
# provide "flattened" iterator for multi-group setting
748755
return get_flattened_list(ids, ngroups, self.levels, self.codes)
749756

750-
@final
751-
def apply_groupwise(
752-
self, f: Callable, data: DataFrame | Series, axis: AxisInt = 0
753-
) -> tuple[list, bool]:
754-
mutated = False
755-
splitter = self._get_splitter(data, axis=axis)
756-
group_keys = self.group_keys_seq
757-
result_values = []
758-
759-
# This calls DataSplitter.__iter__
760-
zipped = zip(group_keys, splitter)
761-
762-
for key, group in zipped:
763-
# Pinning name is needed for
764-
# test_group_apply_once_per_group,
765-
# test_inconsistent_return_type, test_set_group_name,
766-
# test_group_name_available_in_inference_pass,
767-
# test_groupby_multi_timezone
768-
object.__setattr__(group, "name", key)
769-
770-
# group might be modified
771-
group_axes = group.axes
772-
res = f(group)
773-
if not mutated and not _is_indexed_like(res, group_axes, axis):
774-
mutated = True
775-
result_values.append(res)
776-
# getattr pattern for __name__ is needed for functools.partial objects
777-
if len(group_keys) == 0 and getattr(f, "__name__", None) in [
778-
"skew",
779-
"sum",
780-
"prod",
781-
]:
782-
# If group_keys is empty, then no function calls have been made,
783-
# so we will not have raised even if this is an invalid dtype.
784-
# So do one dummy call here to raise appropriate TypeError.
785-
f(data.iloc[:0])
786-
787-
return result_values, mutated
788-
789757
@cache_readonly
790758
def indices(self) -> dict[Hashable, npt.NDArray[np.intp]]:
791759
"""dict {group name -> group indices}"""
@@ -1029,6 +997,61 @@ def _aggregate_series_pure_python(
1029997

1030998
return result
1031999

1000+
@final
1001+
def apply_groupwise(
1002+
self, f: Callable, data: DataFrame | Series, axis: AxisInt = 0
1003+
) -> tuple[list, bool]:
1004+
mutated = False
1005+
splitter = self._get_splitter(data, axis=axis)
1006+
group_keys = self.group_keys_seq
1007+
result_values = []
1008+
1009+
# This calls DataSplitter.__iter__
1010+
zipped = zip(group_keys, splitter)
1011+
1012+
for key, group in zipped:
1013+
# Pinning name is needed for
1014+
# test_group_apply_once_per_group,
1015+
# test_inconsistent_return_type, test_set_group_name,
1016+
# test_group_name_available_in_inference_pass,
1017+
# test_groupby_multi_timezone
1018+
object.__setattr__(group, "name", key)
1019+
1020+
# group might be modified
1021+
group_axes = group.axes
1022+
res = f(group)
1023+
if not mutated and not _is_indexed_like(res, group_axes, axis):
1024+
mutated = True
1025+
result_values.append(res)
1026+
# getattr pattern for __name__ is needed for functools.partial objects
1027+
if len(group_keys) == 0 and getattr(f, "__name__", None) in [
1028+
"skew",
1029+
"sum",
1030+
"prod",
1031+
]:
1032+
# If group_keys is empty, then no function calls have been made,
1033+
# so we will not have raised even if this is an invalid dtype.
1034+
# So do one dummy call here to raise appropriate TypeError.
1035+
f(data.iloc[:0])
1036+
1037+
return result_values, mutated
1038+
1039+
# ------------------------------------------------------------
1040+
# Methods for sorting subsets of our GroupBy's object
1041+
1042+
@final
1043+
@cache_readonly
1044+
def _sort_idx(self) -> npt.NDArray[np.intp]:
1045+
# Counting sort indexer
1046+
ids, _, ngroups = self.group_info
1047+
return get_group_index_sorter(ids, ngroups)
1048+
1049+
@final
1050+
@cache_readonly
1051+
def _sorted_ids(self) -> npt.NDArray[np.intp]:
1052+
ids, _, _ = self.group_info
1053+
return ids.take(self._sort_idx)
1054+
10321055

10331056
class BinGrouper(BaseGrouper):
10341057
"""
@@ -1211,25 +1234,21 @@ def __init__(
12111234
data: NDFrameT,
12121235
labels: npt.NDArray[np.intp],
12131236
ngroups: int,
1237+
*,
1238+
sort_idx: npt.NDArray[np.intp],
1239+
sorted_ids: npt.NDArray[np.intp],
12141240
axis: AxisInt = 0,
12151241
) -> None:
12161242
self.data = data
12171243
self.labels = ensure_platform_int(labels) # _should_ already be np.intp
12181244
self.ngroups = ngroups
12191245

1246+
self._slabels = sorted_ids
1247+
self._sort_idx = sort_idx
1248+
12201249
self.axis = axis
12211250
assert isinstance(axis, int), axis
12221251

1223-
@cache_readonly
1224-
def _slabels(self) -> npt.NDArray[np.intp]:
1225-
# Sorted labels
1226-
return self.labels.take(self._sort_idx)
1227-
1228-
@cache_readonly
1229-
def _sort_idx(self) -> npt.NDArray[np.intp]:
1230-
# Counting sort indexer
1231-
return get_group_index_sorter(self.labels, self.ngroups)
1232-
12331252
def __iter__(self) -> Iterator:
12341253
sdata = self._sorted_data
12351254

@@ -1272,12 +1291,20 @@ def _chop(self, sdata: DataFrame, slice_obj: slice) -> DataFrame:
12721291

12731292

12741293
def _get_splitter(
1275-
data: NDFrame, labels: np.ndarray, ngroups: int, axis: AxisInt = 0
1294+
data: NDFrame,
1295+
labels: npt.NDArray[np.intp],
1296+
ngroups: int,
1297+
*,
1298+
sort_idx: npt.NDArray[np.intp],
1299+
sorted_ids: npt.NDArray[np.intp],
1300+
axis: AxisInt = 0,
12761301
) -> DataSplitter:
12771302
if isinstance(data, Series):
12781303
klass: type[DataSplitter] = SeriesSplitter
12791304
else:
12801305
# i.e. DataFrame
12811306
klass = FrameSplitter
12821307

1283-
return klass(data, labels, ngroups, axis)
1308+
return klass(
1309+
data, labels, ngroups, sort_idx=sort_idx, sorted_ids=sorted_ids, axis=axis
1310+
)

0 commit comments

Comments
 (0)