Skip to content

Commit ecff20e

Browse files
authored
PERF: groupby.pad/backfill follow-up (#43518)
1 parent e0ef148 commit ecff20e

File tree

5 files changed

+32
-90
lines changed

5 files changed

+32
-90
lines changed

pandas/_libs/groupby.pyi

+1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def group_shift_indexer(
3636
def group_fillna_indexer(
3737
out: np.ndarray, # ndarray[intp_t]
3838
labels: np.ndarray, # ndarray[int64_t]
39+
sorted_labels: npt.NDArray[np.intp],
3940
mask: np.ndarray, # ndarray[uint8_t]
4041
direction: Literal["ffill", "bfill"],
4142
limit: int, # int64_t

pandas/_libs/groupby.pyx

+4-6
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,7 @@ def group_shift_indexer(int64_t[::1] out, const intp_t[::1] labels,
322322
@cython.wraparound(False)
323323
@cython.boundscheck(False)
324324
def group_fillna_indexer(ndarray[intp_t] out, ndarray[intp_t] labels,
325+
ndarray[intp_t] sorted_labels,
325326
ndarray[uint8_t] mask, str direction,
326327
int64_t limit, bint dropna) -> None:
327328
"""
@@ -334,6 +335,9 @@ def group_fillna_indexer(ndarray[intp_t] out, ndarray[intp_t] labels,
334335
labels : np.ndarray[np.intp]
335336
Array containing unique label for each group, with its ordering
336337
matching up to the corresponding record in `values`.
338+
sorted_labels : np.ndarray[np.intp]
339+
obtained by `np.argsort(labels, kind="mergesort")`; reversed if
340+
direction == "bfill"
337341
values : np.ndarray[np.uint8]
338342
Containing the truth value of each element.
339343
mask : np.ndarray[np.uint8]
@@ -349,7 +353,6 @@ def group_fillna_indexer(ndarray[intp_t] out, ndarray[intp_t] labels,
349353
"""
350354
cdef:
351355
Py_ssize_t i, N, idx
352-
intp_t[:] sorted_labels
353356
intp_t curr_fill_idx=-1
354357
int64_t filled_vals = 0
355358

@@ -358,11 +361,6 @@ def group_fillna_indexer(ndarray[intp_t] out, ndarray[intp_t] labels,
358361
# Make sure all arrays are the same size
359362
assert N == len(labels) == len(mask)
360363

361-
sorted_labels = np.argsort(labels, kind='mergesort').astype(
362-
np.intp, copy=False)
363-
if direction == 'bfill':
364-
sorted_labels = sorted_labels[::-1]
365-
366364
with nogil:
367365
for i in range(N):
368366
idx = sorted_labels[i]

pandas/core/groupby/generic.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -1063,7 +1063,9 @@ def _aggregate_item_by_item(self, func, *args, **kwargs) -> DataFrame:
10631063
res_df.columns = obj.columns
10641064
return res_df
10651065

1066-
def _wrap_applied_output(self, data, values, not_indexed_same=False):
1066+
def _wrap_applied_output(
1067+
self, data: DataFrame, values: list, not_indexed_same: bool = False
1068+
):
10671069

10681070
if len(values) == 0:
10691071
result = self.obj._constructor(
@@ -1099,9 +1101,7 @@ def _wrap_applied_output(self, data, values, not_indexed_same=False):
10991101
if self.as_index:
11001102
return self.obj._constructor_sliced(values, index=key_index)
11011103
else:
1102-
result = self.obj._constructor(
1103-
values, index=key_index, columns=[self._selection]
1104-
)
1104+
result = self.obj._constructor(values, columns=[self._selection])
11051105
self._insert_inaxis_grouper_inplace(result)
11061106
return result
11071107
else:

pandas/core/groupby/groupby.py

+21-80
Original file line numberDiff line numberDiff line change
@@ -1187,7 +1187,7 @@ def _wrap_transformed_output(
11871187
result.index = self.obj.index
11881188
return result
11891189

1190-
def _wrap_applied_output(self, data, values, not_indexed_same: bool = False):
1190+
def _wrap_applied_output(self, data, values: list, not_indexed_same: bool = False):
11911191
raise AbstractMethodError(self)
11921192

11931193
def _resolve_numeric_only(self, numeric_only: bool | lib.NoDefault) -> bool:
@@ -1667,11 +1667,8 @@ def result_to_bool(
16671667

16681668
return self._get_cythonized_result(
16691669
libgroupby.group_any_all,
1670-
aggregate=True,
16711670
numeric_only=False,
16721671
cython_dtype=np.dtype(np.int8),
1673-
needs_values=True,
1674-
needs_2d=True,
16751672
needs_mask=True,
16761673
needs_nullable=True,
16771674
pre_processing=objs_to_bool,
@@ -1867,10 +1864,7 @@ def std(self, ddof: int = 1):
18671864
"""
18681865
return self._get_cythonized_result(
18691866
libgroupby.group_var,
1870-
aggregate=True,
18711867
needs_counts=True,
1872-
needs_values=True,
1873-
needs_2d=True,
18741868
cython_dtype=np.dtype(np.float64),
18751869
post_processing=lambda vals, inference: np.sqrt(vals),
18761870
ddof=ddof,
@@ -2281,10 +2275,14 @@ def _fill(self, direction: Literal["ffill", "bfill"], limit=None):
22812275
limit = -1
22822276

22832277
ids, _, _ = self.grouper.group_info
2278+
sorted_labels = np.argsort(ids, kind="mergesort").astype(np.intp, copy=False)
2279+
if direction == "bfill":
2280+
sorted_labels = sorted_labels[::-1]
22842281

22852282
col_func = partial(
22862283
libgroupby.group_fillna_indexer,
22872284
labels=ids,
2285+
sorted_labels=sorted_labels,
22882286
direction=direction,
22892287
limit=limit,
22902288
dropna=self.dropna,
@@ -3014,18 +3012,12 @@ def _get_cythonized_result(
30143012
self,
30153013
base_func: Callable,
30163014
cython_dtype: np.dtype,
3017-
aggregate: bool = False,
30183015
numeric_only: bool | lib.NoDefault = lib.no_default,
30193016
needs_counts: bool = False,
3020-
needs_values: bool = False,
3021-
needs_2d: bool = False,
30223017
needs_nullable: bool = False,
3023-
min_count: int | None = None,
30243018
needs_mask: bool = False,
3025-
needs_ngroups: bool = False,
30263019
pre_processing=None,
30273020
post_processing=None,
3028-
fill_value=None,
30293021
**kwargs,
30303022
):
30313023
"""
@@ -3036,26 +3028,13 @@ def _get_cythonized_result(
30363028
base_func : callable, Cythonized function to be called
30373029
cython_dtype : np.dtype
30383030
Type of the array that will be modified by the Cython call.
3039-
aggregate : bool, default False
3040-
Whether the result should be aggregated to match the number of
3041-
groups
30423031
numeric_only : bool, default True
30433032
Whether only numeric datatypes should be computed
30443033
needs_counts : bool, default False
30453034
Whether the counts should be a part of the Cython call
3046-
needs_values : bool, default False
3047-
Whether the values should be a part of the Cython call
3048-
signature
3049-
needs_2d : bool, default False
3050-
Whether the values and result of the Cython call signature
3051-
are 2-dimensional.
3052-
min_count : int, default None
3053-
When not None, min_count for the Cython call
30543035
needs_mask : bool, default False
30553036
Whether boolean mask needs to be part of the Cython call
30563037
signature
3057-
needs_ngroups : bool, default False
3058-
Whether number of groups is part of the Cython call signature
30593038
needs_nullable : bool, default False
30603039
Whether a bool specifying if the input is nullable is part
30613040
of the Cython call signature
@@ -3073,8 +3052,6 @@ def _get_cythonized_result(
30733052
second argument, i.e. the signature should be
30743053
(ndarray, Type). If `needs_nullable=True`, a third argument should be
30753054
`nullable`, to allow for processing specific to nullable values.
3076-
fill_value : any, default None
3077-
The scalar value to use for newly introduced missing values.
30783055
**kwargs : dict
30793056
Extra arguments to be passed back to Cython funcs
30803057
@@ -3086,13 +3063,8 @@ def _get_cythonized_result(
30863063

30873064
if post_processing and not callable(post_processing):
30883065
raise ValueError("'post_processing' must be a callable!")
3089-
if pre_processing:
3090-
if not callable(pre_processing):
3091-
raise ValueError("'pre_processing' must be a callable!")
3092-
if not needs_values:
3093-
raise ValueError(
3094-
"Cannot use 'pre_processing' without specifying 'needs_values'!"
3095-
)
3066+
if pre_processing and not callable(pre_processing):
3067+
raise ValueError("'pre_processing' must be a callable!")
30963068

30973069
grouper = self.grouper
30983070

@@ -3101,29 +3073,14 @@ def _get_cythonized_result(
31013073

31023074
how = base_func.__name__
31033075
base_func = partial(base_func, labels=ids)
3104-
if needs_ngroups:
3105-
base_func = partial(base_func, ngroups=ngroups)
3106-
if min_count is not None:
3107-
base_func = partial(base_func, min_count=min_count)
3108-
3109-
real_2d = how in ["group_any_all", "group_var"]
31103076

31113077
def blk_func(values: ArrayLike) -> ArrayLike:
31123078
values = values.T
31133079
ncols = 1 if values.ndim == 1 else values.shape[1]
31143080

3115-
if aggregate:
3116-
result_sz = ngroups
3117-
else:
3118-
result_sz = values.shape[-1]
3119-
31203081
result: ArrayLike
3121-
result = np.zeros(result_sz * ncols, dtype=cython_dtype)
3122-
if needs_2d:
3123-
if real_2d:
3124-
result = result.reshape((result_sz, ncols))
3125-
else:
3126-
result = result.reshape(-1, 1)
3082+
result = np.zeros(ngroups * ncols, dtype=cython_dtype)
3083+
result = result.reshape((ngroups, ncols))
31273084

31283085
func = partial(base_func, out=result)
31293086

@@ -3133,19 +3090,18 @@ def blk_func(values: ArrayLike) -> ArrayLike:
31333090
counts = np.zeros(self.ngroups, dtype=np.int64)
31343091
func = partial(func, counts=counts)
31353092

3136-
if needs_values:
3137-
vals = values
3138-
if pre_processing:
3139-
vals, inferences = pre_processing(vals)
3093+
vals = values
3094+
if pre_processing:
3095+
vals, inferences = pre_processing(vals)
31403096

3141-
vals = vals.astype(cython_dtype, copy=False)
3142-
if needs_2d and vals.ndim == 1:
3143-
vals = vals.reshape((-1, 1))
3144-
func = partial(func, values=vals)
3097+
vals = vals.astype(cython_dtype, copy=False)
3098+
if vals.ndim == 1:
3099+
vals = vals.reshape((-1, 1))
3100+
func = partial(func, values=vals)
31453101

31463102
if needs_mask:
31473103
mask = isna(values).view(np.uint8)
3148-
if needs_2d and mask.ndim == 1:
3104+
if mask.ndim == 1:
31493105
mask = mask.reshape(-1, 1)
31503106
func = partial(func, mask=mask)
31513107

@@ -3155,11 +3111,9 @@ def blk_func(values: ArrayLike) -> ArrayLike:
31553111

31563112
func(**kwargs) # Call func to modify indexer values in place
31573113

3158-
if real_2d and values.ndim == 1:
3114+
if values.ndim == 1:
31593115
assert result.shape[1] == 1, result.shape
31603116
result = result[:, 0]
3161-
if needs_mask:
3162-
mask = mask[:, 0]
31633117

31643118
if post_processing:
31653119
pp_kwargs = {}
@@ -3168,17 +3122,10 @@ def blk_func(values: ArrayLike) -> ArrayLike:
31683122

31693123
result = post_processing(result, inferences, **pp_kwargs)
31703124

3171-
if needs_2d and not real_2d:
3172-
if result.ndim == 2:
3173-
assert result.shape[1] == 1
3174-
# error: No overload variant of "__getitem__" of "ExtensionArray"
3175-
# matches argument type "Tuple[slice, int]"
3176-
result = result[:, 0] # type: ignore[call-overload]
3177-
31783125
return result.T
31793126

31803127
obj = self._obj_with_exclusions
3181-
if obj.ndim == 2 and self.axis == 0 and needs_2d and real_2d:
3128+
if obj.ndim == 2 and self.axis == 0:
31823129
# Operate block-wise instead of column-by-column
31833130
mgr = obj._mgr
31843131
if numeric_only:
@@ -3187,10 +3134,7 @@ def blk_func(values: ArrayLike) -> ArrayLike:
31873134
# setting ignore_failures=False for troubleshooting
31883135
res_mgr = mgr.grouped_reduce(blk_func, ignore_failures=False)
31893136
output = type(obj)(res_mgr)
3190-
if aggregate:
3191-
return self._wrap_aggregated_output(output)
3192-
else:
3193-
return self._wrap_transformed_output(output)
3137+
return self._wrap_aggregated_output(output)
31943138

31953139
error_msg = ""
31963140
for idx, obj in enumerate(self._iterate_slices()):
@@ -3222,10 +3166,7 @@ def blk_func(values: ArrayLike) -> ArrayLike:
32223166
if not output and error_msg != "":
32233167
raise TypeError(error_msg)
32243168

3225-
if aggregate:
3226-
return self._wrap_aggregated_output(output)
3227-
else:
3228-
return self._wrap_transformed_output(output)
3169+
return self._wrap_aggregated_output(output)
32293170

32303171
@final
32313172
@Substitution(name="groupby")

pandas/core/groupby/ops.py

+2
Original file line numberDiff line numberDiff line change
@@ -757,6 +757,7 @@ def indices(self):
757757
keys = [ping.group_index for ping in self.groupings]
758758
return get_indexer_dict(codes_list, keys)
759759

760+
@final
760761
@property
761762
def codes(self) -> list[np.ndarray]:
762763
return [ping.codes for ping in self.groupings]
@@ -836,6 +837,7 @@ def reconstructed_codes(self) -> list[np.ndarray]:
836837
ids, obs_ids, _ = self.group_info
837838
return decons_obs_group_ids(ids, obs_ids, self.shape, codes, xnull=True)
838839

840+
@final
839841
@cache_readonly
840842
def result_arraylike(self) -> ArrayLike:
841843
"""

0 commit comments

Comments
 (0)