Skip to content

Commit 84aca21

Browse files
authored
PERF: Don't sort labels in groupby.ffill/bfill (#56902)
* PERF: Don't sort labels in groupby.ffill/bfill
* PR#
* fixup
1 parent 7cd8ae5 commit 84aca21

File tree

4 files changed

+39
-39
lines changed

4 files changed

+39
-39
lines changed

doc/source/whatsnew/v2.3.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ Performance improvements
103103
~~~~~~~~~~~~~~~~~~~~~~~~
104104
- Performance improvement in :meth:`DataFrame.join` for sorted but non-unique indexes (:issue:`56941`)
105105
- Performance improvement in :meth:`DataFrame.join` when left and/or right are non-unique and ``how`` is ``"left"``, ``"right"``, or ``"inner"`` (:issue:`56817`)
106+
- Performance improvement in :meth:`DataFrameGroupBy.ffill`, :meth:`DataFrameGroupBy.bfill`, :meth:`SeriesGroupBy.ffill`, and :meth:`SeriesGroupBy.bfill` (:issue:`56902`)
106107
- Performance improvement in :meth:`Index.take` when ``indices`` is a full range indexer from zero to length of index (:issue:`56806`)
107108
-
108109

pandas/_libs/groupby.pyi

+2-2
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,10 @@ def group_shift_indexer(
4242
def group_fillna_indexer(
4343
out: np.ndarray, # ndarray[intp_t]
4444
labels: np.ndarray, # ndarray[int64_t]
45-
sorted_labels: npt.NDArray[np.intp],
4645
mask: npt.NDArray[np.uint8],
4746
limit: int, # int64_t
48-
dropna: bool,
47+
compute_ffill: bool,
48+
ngroups: int,
4949
) -> None: ...
5050
def group_any_all(
5151
out: np.ndarray, # uint8_t[::1]

pandas/_libs/groupby.pyx

+33-31
Original file line numberDiff line numberDiff line change
@@ -493,10 +493,10 @@ def group_shift_indexer(
493493
def group_fillna_indexer(
494494
ndarray[intp_t] out,
495495
ndarray[intp_t] labels,
496-
ndarray[intp_t] sorted_labels,
497496
ndarray[uint8_t] mask,
498497
int64_t limit,
499-
bint dropna,
498+
bint compute_ffill,
499+
int ngroups,
500500
) -> None:
501501
"""
502502
Indexes how to fill values forwards or backwards within a group.
@@ -508,50 +508,52 @@ def group_fillna_indexer(
508508
labels : np.ndarray[np.intp]
509509
Array containing unique label for each group, with its ordering
510510
matching up to the corresponding record in `values`.
511-
sorted_labels : np.ndarray[np.intp]
512-
obtained by `np.argsort(labels, kind="mergesort")`
513-
values : np.ndarray[np.uint8]
514-
Containing the truth value of each element.
515511
mask : np.ndarray[np.uint8]
516512
Indicating whether a value is na or not.
517-
limit : Consecutive values to fill before stopping, or -1 for no limit
518-
dropna : Flag to indicate if NaN groups should return all NaN values
513+
limit : int64_t
514+
Consecutive values to fill before stopping, or -1 for no limit.
515+
compute_ffill : bint
516+
Whether to compute ffill or bfill.
517+
ngroups : int
518+
Number of groups, larger than all entries of `labels`.
519519

520520
Notes
521521
-----
522522
This method modifies the `out` parameter rather than returning an object
523523
"""
524524
cdef:
525-
Py_ssize_t i, N, idx
526-
intp_t curr_fill_idx=-1
527-
int64_t filled_vals = 0
528-
529-
N = len(out)
525+
Py_ssize_t idx, N = len(out)
526+
intp_t label
527+
intp_t[::1] last = -1 * np.ones(ngroups, dtype=np.intp)
528+
intp_t[::1] fill_count = np.zeros(ngroups, dtype=np.intp)
530529

531530
# Make sure all arrays are the same size
532531
assert N == len(labels) == len(mask)
533532

534533
with nogil:
535-
for i in range(N):
536-
idx = sorted_labels[i]
537-
if dropna and labels[idx] == -1: # nan-group gets nan-values
538-
curr_fill_idx = -1
534+
# Can't use for loop with +/- step
535+
# https://github.com/cython/cython/issues/1106
536+
idx = 0 if compute_ffill else N-1
537+
for _ in range(N):
538+
label = labels[idx]
539+
if label == -1: # na-group gets na-values
540+
out[idx] = -1
539541
elif mask[idx] == 1: # is missing
540542
# Stop filling once we've hit the limit
541-
if filled_vals >= limit and limit != -1:
542-
curr_fill_idx = -1
543-
filled_vals += 1
544-
else: # reset items when not missing
545-
filled_vals = 0
546-
curr_fill_idx = idx
547-
548-
out[idx] = curr_fill_idx
549-
550-
# If we move to the next group, reset
551-
# the fill_idx and counter
552-
if i == N - 1 or labels[idx] != labels[sorted_labels[i + 1]]:
553-
curr_fill_idx = -1
554-
filled_vals = 0
543+
if limit != -1 and fill_count[label] >= limit:
544+
out[idx] = -1
545+
else:
546+
out[idx] = last[label]
547+
fill_count[label] += 1
548+
else:
549+
fill_count[label] = 0 # reset items when not missing
550+
last[label] = idx
551+
out[idx] = idx
552+
553+
if compute_ffill:
554+
idx += 1
555+
else:
556+
idx -= 1
555557

556558

557559
@cython.boundscheck(False)

pandas/core/groupby/groupby.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -3958,17 +3958,14 @@ def _fill(self, direction: Literal["ffill", "bfill"], limit: int | None = None):
39583958
if limit is None:
39593959
limit = -1
39603960

3961-
ids, _, _ = self._grouper.group_info
3962-
sorted_labels = np.argsort(ids, kind="mergesort").astype(np.intp, copy=False)
3963-
if direction == "bfill":
3964-
sorted_labels = sorted_labels[::-1]
3961+
ids, _, ngroups = self._grouper.group_info
39653962

39663963
col_func = partial(
39673964
libgroupby.group_fillna_indexer,
39683965
labels=ids,
3969-
sorted_labels=sorted_labels,
39703966
limit=limit,
3971-
dropna=self.dropna,
3967+
compute_ffill=(direction == "ffill"),
3968+
ngroups=ngroups,
39723969
)
39733970

39743971
def blk_func(values: ArrayLike) -> ArrayLike:

0 commit comments

Comments
 (0)