From 2c7038dc6f38f742b4ab4dc45503bae621ebc96d Mon Sep 17 00:00:00 2001 From: Brock Date: Thu, 9 Sep 2021 10:13:28 -0700 Subject: [PATCH] REF: Groupby.pad/backfill operate blockwise --- pandas/_libs/groupby.pyi | 2 +- pandas/_libs/groupby.pyx | 4 +- pandas/core/groupby/groupby.py | 66 ++++++++++++++++++++-------- pandas/tests/groupby/test_missing.py | 2 + 4 files changed, 53 insertions(+), 21 deletions(-) diff --git a/pandas/_libs/groupby.pyi b/pandas/_libs/groupby.pyi index b363524e4e592..42bb1621a53bc 100644 --- a/pandas/_libs/groupby.pyi +++ b/pandas/_libs/groupby.pyi @@ -32,7 +32,7 @@ def group_shift_indexer( periods: int, ) -> None: ... def group_fillna_indexer( - out: np.ndarray, # ndarray[int64_t] + out: np.ndarray, # ndarray[intp_t] labels: np.ndarray, # ndarray[int64_t] mask: np.ndarray, # ndarray[uint8_t] direction: Literal["ffill", "bfill"], diff --git a/pandas/_libs/groupby.pyx b/pandas/_libs/groupby.pyx index 40e1049c39588..7b18e238ba195 100644 --- a/pandas/_libs/groupby.pyx +++ b/pandas/_libs/groupby.pyx @@ -321,7 +321,7 @@ def group_shift_indexer(int64_t[::1] out, const intp_t[::1] labels, @cython.wraparound(False) @cython.boundscheck(False) -def group_fillna_indexer(ndarray[int64_t] out, ndarray[intp_t] labels, +def group_fillna_indexer(ndarray[intp_t] out, ndarray[intp_t] labels, ndarray[uint8_t] mask, str direction, int64_t limit, bint dropna) -> None: """ @@ -329,7 +329,7 @@ def group_fillna_indexer(ndarray[int64_t] out, ndarray[intp_t] labels, Parameters ---------- - out : np.ndarray[np.int64] + out : np.ndarray[np.intp] Values into which this method will write its results. labels : np.ndarray[np.intp] Array containing unique label for each group, with its ordering diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 9eea81d1aa152..7c6b47c63c7fc 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -1161,7 +1161,10 @@ def _wrap_transformed_output( Series or DataFrame Series for SeriesGroupBy, DataFrame for DataFrameGroupBy """ - result = self._indexed_output_to_ndframe(output) + if isinstance(output, (Series, DataFrame)): + result = output + else: + result = self._indexed_output_to_ndframe(output) if self.axis == 1: # Only relevant for DataFrameGroupBy @@ -2237,17 +2240,55 @@ def _fill(self, direction: Literal["ffill", "bfill"], limit=None): if limit is None: limit = -1 - return self._get_cythonized_result( + ids, _, _ = self.grouper.group_info + + col_func = partial( libgroupby.group_fillna_indexer, - numeric_only=False, - needs_mask=True, - cython_dtype=np.dtype(np.int64), - result_is_index=True, + labels=ids, direction=direction, limit=limit, dropna=self.dropna, ) + def blk_func(values: ArrayLike) -> ArrayLike: + mask = isna(values) + if values.ndim == 1: + indexer = np.empty(values.shape, dtype=np.intp) + col_func(out=indexer, mask=mask) + return algorithms.take_nd(values, indexer) + + else: + # We broadcast algorithms.take_nd analogous to + # np.take_along_axis + + # Note: we only get here with backfill/pad, + # so if we have a dtype that cannot hold NAs, + # then there will be no -1s in indexer, so we can use + # the original dtype (no need to ensure_dtype_can_hold_na) + if isinstance(values, np.ndarray): + out = np.empty(values.shape, dtype=values.dtype) + else: + out = type(values)._empty(values.shape, dtype=values.dtype) + + for i in range(len(values)): + # call group_fillna_indexer column-wise + indexer = np.empty(values.shape[1], dtype=np.intp) + col_func(out=indexer, mask=mask[i]) + out[i, :] = algorithms.take_nd(values[i], indexer) + return out + + obj = self._obj_with_exclusions + if self.axis == 1: + obj = obj.T + mgr = obj._mgr + res_mgr = mgr.apply(blk_func) + + new_obj = obj._constructor(res_mgr) + if isinstance(new_obj, Series): + new_obj.name = obj.name + + return self._wrap_transformed_output(new_obj) + @final @Substitution(name="groupby") def pad(self, limit=None): @@ -2920,7 +2961,6 @@ def _get_cythonized_result( min_count: int | None = None, needs_mask: bool = False, needs_ngroups: bool = False, - result_is_index: bool = False, pre_processing=None, post_processing=None, fill_value=None, @@ -2957,9 +2997,6 @@ def _get_cythonized_result( needs_nullable : bool, default False Whether a bool specifying if the input is nullable is part of the Cython call signature - result_is_index : bool, default False - Whether the result of the Cython operation is an index of - values to be retrieved, instead of the actual values themselves pre_processing : function, default None Function to be applied to `values` prior to passing to Cython. Function should return a tuple where the first element is the @@ -2985,8 +3022,6 @@ def _get_cythonized_result( """ numeric_only = self._resolve_numeric_only(numeric_only) - if result_is_index and aggregate: - raise ValueError("'result_is_index' and 'aggregate' cannot both be True!") if post_processing and not callable(post_processing): raise ValueError("'post_processing' must be a callable!") if pre_processing: @@ -3057,14 +3092,9 @@ def blk_func(values: ArrayLike) -> ArrayLike: func(**kwargs) # Call func to modify indexer values in place - if result_is_index: - result = algorithms.take_nd(values, result, fill_value=fill_value) - if real_2d and values.ndim == 1: assert result.shape[1] == 1, result.shape - # error: No overload variant of "__getitem__" of "ExtensionArray" - # matches argument type "Tuple[slice, int]" - result = result[:, 0] # type: ignore[call-overload] + result = result[:, 0] if needs_mask: mask = mask[:, 0] diff --git a/pandas/tests/groupby/test_missing.py b/pandas/tests/groupby/test_missing.py index f3149abb52291..525bba984fca5 100644 --- a/pandas/tests/groupby/test_missing.py +++ b/pandas/tests/groupby/test_missing.py @@ -130,6 +130,8 @@ def test_ffill_handles_nan_groups(dropna, method, has_nan_group): ridx = expected_rows.get((method, dropna, has_nan_group)) expected = df_without_nan_rows.reindex(ridx).reset_index(drop=True) + # columns are a 'take' on df.columns, which are object dtype + expected.columns = expected.columns.astype(object) tm.assert_frame_equal(result, expected)