diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py
index 457352564f255..96f39bb99e544 100644
--- a/pandas/core/groupby/groupby.py
+++ b/pandas/core/groupby/groupby.py
@@ -1167,8 +1167,8 @@ def _wrap_applied_output(
     @final
     def _numba_prep(self, data: DataFrame):
         ids, _, ngroups = self.grouper.group_info
-        sorted_index = get_group_index_sorter(ids, ngroups)
-        sorted_ids = algorithms.take_nd(ids, sorted_index, allow_fill=False)
+        sorted_index = self.grouper._sort_idx
+        sorted_ids = self.grouper._sorted_ids
         sorted_data = data.take(sorted_index, axis=self.axis).to_numpy()
 
         if len(self.grouper.groupings) > 1:
diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py
index 9a06a3da28e15..0acc7fe29b5db 100644
--- a/pandas/core/groupby/ops.py
+++ b/pandas/core/groupby/ops.py
@@ -734,7 +734,14 @@ def _get_splitter(self, data: NDFrame, axis: AxisInt = 0) -> DataSplitter:
         Generator yielding subsetted objects
         """
         ids, _, ngroups = self.group_info
-        return _get_splitter(data, ids, ngroups, axis=axis)
+        return _get_splitter(
+            data,
+            ids,
+            ngroups,
+            sorted_ids=self._sorted_ids,
+            sort_idx=self._sort_idx,
+            axis=axis,
+        )
 
     @final
     @cache_readonly
@@ -747,45 +754,6 @@ def group_keys_seq(self):
         # provide "flattened" iterator for multi-group setting
         return get_flattened_list(ids, ngroups, self.levels, self.codes)
 
-    @final
-    def apply_groupwise(
-        self, f: Callable, data: DataFrame | Series, axis: AxisInt = 0
-    ) -> tuple[list, bool]:
-        mutated = False
-        splitter = self._get_splitter(data, axis=axis)
-        group_keys = self.group_keys_seq
-        result_values = []
-
-        # This calls DataSplitter.__iter__
-        zipped = zip(group_keys, splitter)
-
-        for key, group in zipped:
-            # Pinning name is needed for
-            # test_group_apply_once_per_group,
-            # test_inconsistent_return_type, test_set_group_name,
-            # test_group_name_available_in_inference_pass,
-            # test_groupby_multi_timezone
-            object.__setattr__(group, "name", key)
-
-            # group might be modified
-            group_axes = group.axes
-            res = f(group)
-            if not mutated and not _is_indexed_like(res, group_axes, axis):
-                mutated = True
-            result_values.append(res)
-        # getattr pattern for __name__ is needed for functools.partial objects
-        if len(group_keys) == 0 and getattr(f, "__name__", None) in [
-            "skew",
-            "sum",
-            "prod",
-        ]:
-            # If group_keys is empty, then no function calls have been made,
-            # so we will not have raised even if this is an invalid dtype.
-            # So do one dummy call here to raise appropriate TypeError.
-            f(data.iloc[:0])
-
-        return result_values, mutated
-
     @cache_readonly
     def indices(self) -> dict[Hashable, npt.NDArray[np.intp]]:
         """dict {group name -> group indices}"""
@@ -1029,6 +997,61 @@ def _aggregate_series_pure_python(
 
         return result
 
+    @final
+    def apply_groupwise(
+        self, f: Callable, data: DataFrame | Series, axis: AxisInt = 0
+    ) -> tuple[list, bool]:
+        mutated = False
+        splitter = self._get_splitter(data, axis=axis)
+        group_keys = self.group_keys_seq
+        result_values = []
+
+        # This calls DataSplitter.__iter__
+        zipped = zip(group_keys, splitter)
+
+        for key, group in zipped:
+            # Pinning name is needed for
+            # test_group_apply_once_per_group,
+            # test_inconsistent_return_type, test_set_group_name,
+            # test_group_name_available_in_inference_pass,
+            # test_groupby_multi_timezone
+            object.__setattr__(group, "name", key)
+
+            # group might be modified
+            group_axes = group.axes
+            res = f(group)
+            if not mutated and not _is_indexed_like(res, group_axes, axis):
+                mutated = True
+            result_values.append(res)
+        # getattr pattern for __name__ is needed for functools.partial objects
+        if len(group_keys) == 0 and getattr(f, "__name__", None) in [
+            "skew",
+            "sum",
+            "prod",
+        ]:
+            # If group_keys is empty, then no function calls have been made,
+            # so we will not have raised even if this is an invalid dtype.
+            # So do one dummy call here to raise appropriate TypeError.
+            f(data.iloc[:0])
+
+        return result_values, mutated
+
+    # ------------------------------------------------------------
+    # Methods for sorting subsets of our GroupBy's object
+
+    @final
+    @cache_readonly
+    def _sort_idx(self) -> npt.NDArray[np.intp]:
+        # Counting sort indexer
+        ids, _, ngroups = self.group_info
+        return get_group_index_sorter(ids, ngroups)
+
+    @final
+    @cache_readonly
+    def _sorted_ids(self) -> npt.NDArray[np.intp]:
+        ids, _, _ = self.group_info
+        return ids.take(self._sort_idx)
+
 
 class BinGrouper(BaseGrouper):
     """
@@ -1211,25 +1234,21 @@ def __init__(
         data: NDFrameT,
         labels: npt.NDArray[np.intp],
         ngroups: int,
+        *,
+        sort_idx: npt.NDArray[np.intp],
+        sorted_ids: npt.NDArray[np.intp],
         axis: AxisInt = 0,
     ) -> None:
         self.data = data
         self.labels = ensure_platform_int(labels)  # _should_ already be np.intp
         self.ngroups = ngroups
+        self._slabels = sorted_ids
+        self._sort_idx = sort_idx
+
         self.axis = axis
         assert isinstance(axis, int), axis
 
-    @cache_readonly
-    def _slabels(self) -> npt.NDArray[np.intp]:
-        # Sorted labels
-        return self.labels.take(self._sort_idx)
-
-    @cache_readonly
-    def _sort_idx(self) -> npt.NDArray[np.intp]:
-        # Counting sort indexer
-        return get_group_index_sorter(self.labels, self.ngroups)
-
     def __iter__(self) -> Iterator:
         sdata = self._sorted_data
 
@@ -1272,7 +1291,13 @@ def _chop(self, sdata: DataFrame, slice_obj: slice) -> DataFrame:
 
 
 def _get_splitter(
-    data: NDFrame, labels: np.ndarray, ngroups: int, axis: AxisInt = 0
+    data: NDFrame,
+    labels: npt.NDArray[np.intp],
+    ngroups: int,
+    *,
+    sort_idx: npt.NDArray[np.intp],
+    sorted_ids: npt.NDArray[np.intp],
+    axis: AxisInt = 0,
 ) -> DataSplitter:
     if isinstance(data, Series):
         klass: type[DataSplitter] = SeriesSplitter
@@ -1280,4 +1305,6 @@ def _get_splitter(
     # i.e. DataFrame
     klass = FrameSplitter
 
-    return klass(data, labels, ngroups, axis)
+    return klass(
+        data, labels, ngroups, sort_idx=sort_idx, sorted_ids=sorted_ids, axis=axis
+    )
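Note (not part of the patch): the diff hoists the counting-sort indexer out of DataSplitter and caches it on BaseGrouper as `_sort_idx` / `_sorted_ids`, so `_numba_prep` and every splitter share one sort per grouper instead of each call site recomputing it; `apply_groupwise` moves unchanged to sit next to the other aggregation helpers. Below is a minimal standalone sketch of the same pattern, with `functools.cached_property` standing in for pandas' `cache_readonly`. `Grouper`, `group_sort_indexer`, and `split` are illustrative names, not pandas API, and the pure-Python sort is a stand-in for the Cython `get_group_index_sorter`:

from functools import cached_property

import numpy as np


def group_sort_indexer(ids: np.ndarray, ngroups: int) -> np.ndarray:
    # Stable counting sort over group ids, O(n + ngroups) -- the same job
    # get_group_index_sorter does in Cython inside pandas.
    counts = np.bincount(ids, minlength=ngroups)
    write_at = np.zeros(ngroups + 1, dtype=np.intp)
    np.cumsum(counts, out=write_at[1:])
    indexer = np.empty(len(ids), dtype=np.intp)
    for i, gid in enumerate(ids):
        indexer[write_at[gid]] = i
        write_at[gid] += 1
    return indexer


class Grouper:
    """Toy stand-in for BaseGrouper: sort products are computed once."""

    def __init__(self, ids, ngroups: int) -> None:
        self.ids = np.asarray(ids, dtype=np.intp)
        self.ngroups = ngroups

    @cached_property
    def sort_idx(self) -> np.ndarray:
        # Cached on first access, so numba prep, splitters, etc.
        # all reuse the same indexer.
        return group_sort_indexer(self.ids, self.ngroups)

    @cached_property
    def sorted_ids(self) -> np.ndarray:
        return self.ids[self.sort_idx]

    def split(self, values):
        # Once rows are ordered by group id, every group occupies a
        # contiguous block and iteration reduces to cheap slicing --
        # the DataSplitter idea.
        sorted_values = np.asarray(values)[self.sort_idx]
        boundaries = np.flatnonzero(np.diff(self.sorted_ids)) + 1
        yield from np.split(sorted_values, boundaries)


grouper = Grouper([2, 0, 1, 0, 2, 1], ngroups=3)
for block in grouper.split([10.0, 11.0, 12.0, 13.0, 14.0, 15.0]):
    print(block)  # [11. 13.], then [12. 15.], then [10. 14.]

Passing `sort_idx`/`sorted_ids` into the splitter as keyword-only constructor arguments, as the diff does, keeps DataSplitter a dumb consumer: it no longer owns any sorting logic, it just slices data that the grouper has already put in group order.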