Skip to content

Commit 22aa73c

Browse files
authored
REF: Groupby.pad/backfill operate blockwise (#43478)
1 parent 8d664c5 commit 22aa73c

File tree

4 files changed: +53 -21 lines changed

pandas/_libs/groupby.pyi

+1 -1
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,7 @@ def group_shift_indexer(
3232
periods: int,
3333
) -> None: ...
3434
def group_fillna_indexer(
35-
out: np.ndarray, # ndarray[int64_t]
35+
out: np.ndarray, # ndarray[intp_t]
3636
labels: np.ndarray, # ndarray[int64_t]
3737
mask: np.ndarray, # ndarray[uint8_t]
3838
direction: Literal["ffill", "bfill"],

pandas/_libs/groupby.pyx

+2 -2
Original file line number | Diff line number | Diff line change
@@ -321,15 +321,15 @@ def group_shift_indexer(int64_t[::1] out, const intp_t[::1] labels,
321321

322322
@cython.wraparound(False)
323323
@cython.boundscheck(False)
324-
def group_fillna_indexer(ndarray[int64_t] out, ndarray[intp_t] labels,
324+
def group_fillna_indexer(ndarray[intp_t] out, ndarray[intp_t] labels,
325325
ndarray[uint8_t] mask, str direction,
326326
int64_t limit, bint dropna) -> None:
327327
"""
328328
Indexes how to fill values forwards or backwards within a group.
329329

330330
Parameters
331331
----------
332-
out : np.ndarray[np.int64]
332+
out : np.ndarray[np.intp]
333333
Values into which this method will write its results.
334334
labels : np.ndarray[np.intp]
335335
Array containing unique label for each group, with its ordering

pandas/core/groupby/groupby.py

+48 -18
Original file line number | Diff line number | Diff line change
@@ -1174,7 +1174,10 @@ def _wrap_transformed_output(
11741174
Series or DataFrame
11751175
Series for SeriesGroupBy, DataFrame for DataFrameGroupBy
11761176
"""
1177-
result = self._indexed_output_to_ndframe(output)
1177+
if isinstance(output, (Series, DataFrame)):
1178+
result = output
1179+
else:
1180+
result = self._indexed_output_to_ndframe(output)
11781181

11791182
if self.axis == 1:
11801183
# Only relevant for DataFrameGroupBy
@@ -2254,17 +2257,55 @@ def _fill(self, direction: Literal["ffill", "bfill"], limit=None):
22542257
if limit is None:
22552258
limit = -1
22562259

2257-
return self._get_cythonized_result(
2260+
ids, _, _ = self.grouper.group_info
2261+
2262+
col_func = partial(
22582263
libgroupby.group_fillna_indexer,
2259-
numeric_only=False,
2260-
needs_mask=True,
2261-
cython_dtype=np.dtype(np.int64),
2262-
result_is_index=True,
2264+
labels=ids,
22632265
direction=direction,
22642266
limit=limit,
22652267
dropna=self.dropna,
22662268
)
22672269

2270+
def blk_func(values: ArrayLike) -> ArrayLike:
2271+
mask = isna(values)
2272+
if values.ndim == 1:
2273+
indexer = np.empty(values.shape, dtype=np.intp)
2274+
col_func(out=indexer, mask=mask)
2275+
return algorithms.take_nd(values, indexer)
2276+
2277+
else:
2278+
# We broadcast algorithms.take_nd analogous to
2279+
# np.take_along_axis
2280+
2281+
# Note: we only get here with backfill/pad,
2282+
# so if we have a dtype that cannot hold NAs,
2283+
# then there will be no -1s in indexer, so we can use
2284+
# the original dtype (no need to ensure_dtype_can_hold_na)
2285+
if isinstance(values, np.ndarray):
2286+
out = np.empty(values.shape, dtype=values.dtype)
2287+
else:
2288+
out = type(values)._empty(values.shape, dtype=values.dtype)
2289+
2290+
for i in range(len(values)):
2291+
# call group_fillna_indexer column-wise
2292+
indexer = np.empty(values.shape[1], dtype=np.intp)
2293+
col_func(out=indexer, mask=mask[i])
2294+
out[i, :] = algorithms.take_nd(values[i], indexer)
2295+
return out
2296+
2297+
obj = self._obj_with_exclusions
2298+
if self.axis == 1:
2299+
obj = obj.T
2300+
mgr = obj._mgr
2301+
res_mgr = mgr.apply(blk_func)
2302+
2303+
new_obj = obj._constructor(res_mgr)
2304+
if isinstance(new_obj, Series):
2305+
new_obj.name = obj.name
2306+
2307+
return self._wrap_transformed_output(new_obj)
2308+
22682309
@final
22692310
@Substitution(name="groupby")
22702311
def pad(self, limit=None):
@@ -2944,7 +2985,6 @@ def _get_cythonized_result(
29442985
min_count: int | None = None,
29452986
needs_mask: bool = False,
29462987
needs_ngroups: bool = False,
2947-
result_is_index: bool = False,
29482988
pre_processing=None,
29492989
post_processing=None,
29502990
fill_value=None,
@@ -2981,9 +3021,6 @@ def _get_cythonized_result(
29813021
needs_nullable : bool, default False
29823022
Whether a bool specifying if the input is nullable is part
29833023
of the Cython call signature
2984-
result_is_index : bool, default False
2985-
Whether the result of the Cython operation is an index of
2986-
values to be retrieved, instead of the actual values themselves
29873024
pre_processing : function, default None
29883025
Function to be applied to `values` prior to passing to Cython.
29893026
Function should return a tuple where the first element is the
@@ -3009,8 +3046,6 @@ def _get_cythonized_result(
30093046
"""
30103047
numeric_only = self._resolve_numeric_only(numeric_only)
30113048

3012-
if result_is_index and aggregate:
3013-
raise ValueError("'result_is_index' and 'aggregate' cannot both be True!")
30143049
if post_processing and not callable(post_processing):
30153050
raise ValueError("'post_processing' must be a callable!")
30163051
if pre_processing:
@@ -3082,14 +3117,9 @@ def blk_func(values: ArrayLike) -> ArrayLike:
30823117

30833118
func(**kwargs) # Call func to modify indexer values in place
30843119

3085-
if result_is_index:
3086-
result = algorithms.take_nd(values, result, fill_value=fill_value)
3087-
30883120
if real_2d and values.ndim == 1:
30893121
assert result.shape[1] == 1, result.shape
3090-
# error: No overload variant of "__getitem__" of "ExtensionArray"
3091-
# matches argument type "Tuple[slice, int]"
3092-
result = result[:, 0] # type: ignore[call-overload]
3122+
result = result[:, 0]
30933123
if needs_mask:
30943124
mask = mask[:, 0]
30953125

pandas/tests/groupby/test_missing.py

+2
Original file line number | Diff line number | Diff line change
@@ -130,6 +130,8 @@ def test_ffill_handles_nan_groups(dropna, method, has_nan_group):
130130

131131
ridx = expected_rows.get((method, dropna, has_nan_group))
132132
expected = df_without_nan_rows.reindex(ridx).reset_index(drop=True)
133+
# columns are a 'take' on df.columns, which are object dtype
134+
expected.columns = expected.columns.astype(object)
133135

134136
tm.assert_frame_equal(result, expected)
135137

0 commit comments

Comments
 (0)