Skip to content

Commit 5f37c47

Browse files
authored
REF: GroupBy.any/all use WrappedCythonOp (#52089)
1 parent 33a3fb1 commit 5f37c47

File tree

2 files changed

+47
-157
lines changed

2 files changed

+47
-157
lines changed

pandas/core/groupby/groupby.py

+12-152
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,6 @@ class providing the base-class of operations.
9595
from pandas.core._numba import executor
9696
from pandas.core.arrays import (
9797
BaseMaskedArray,
98-
BooleanArray,
9998
Categorical,
10099
ExtensionArray,
101100
FloatingArray,
@@ -1545,6 +1544,8 @@ def array_func(values: ArrayLike) -> ArrayLike:
15451544
# and non-applicable functions
15461545
# try to python agg
15471546
# TODO: shouldn't min_count matter?
1547+
if how in ["any", "all"]:
1548+
raise # TODO: re-raise as TypeError?
15481549
result = self._agg_py_fallback(values, ndim=data.ndim, alt=alt)
15491550

15501551
return result
@@ -1694,45 +1695,6 @@ def _obj_1d_constructor(self) -> Callable:
16941695
assert isinstance(self.obj, Series)
16951696
return self.obj._constructor
16961697

1697-
@final
1698-
def _bool_agg(self, val_test: Literal["any", "all"], skipna: bool):
1699-
"""
1700-
Shared func to call any / all Cython GroupBy implementations.
1701-
"""
1702-
1703-
def objs_to_bool(vals: ArrayLike) -> tuple[np.ndarray, type]:
1704-
if is_object_dtype(vals.dtype) and skipna:
1705-
# GH#37501: don't raise on pd.NA when skipna=True
1706-
mask = isna(vals)
1707-
if mask.any():
1708-
# mask on original values computed separately
1709-
vals = vals.copy()
1710-
vals[mask] = True
1711-
elif isinstance(vals, BaseMaskedArray):
1712-
vals = vals._data
1713-
vals = vals.astype(bool, copy=False)
1714-
return vals.view(np.int8), bool
1715-
1716-
def result_to_bool(
1717-
result: np.ndarray,
1718-
inference: type,
1719-
result_mask,
1720-
) -> ArrayLike:
1721-
if result_mask is not None:
1722-
return BooleanArray(result.astype(bool, copy=False), result_mask)
1723-
else:
1724-
return result.astype(inference, copy=False)
1725-
1726-
return self._get_cythonized_result(
1727-
libgroupby.group_any_all,
1728-
numeric_only=False,
1729-
cython_dtype=np.dtype(np.int8),
1730-
pre_processing=objs_to_bool,
1731-
post_processing=result_to_bool,
1732-
val_test=val_test,
1733-
skipna=skipna,
1734-
)
1735-
17361698
@final
17371699
@Substitution(name="groupby")
17381700
@Appender(_common_see_also)
@@ -1751,7 +1713,11 @@ def any(self, skipna: bool = True):
17511713
DataFrame or Series of boolean values, where a value is True if any element
17521714
is True within its respective group, False otherwise.
17531715
"""
1754-
return self._bool_agg("any", skipna)
1716+
return self._cython_agg_general(
1717+
"any",
1718+
alt=lambda x: Series(x).any(skipna=skipna),
1719+
skipna=skipna,
1720+
)
17551721

17561722
@final
17571723
@Substitution(name="groupby")
@@ -1771,7 +1737,11 @@ def all(self, skipna: bool = True):
17711737
DataFrame or Series of boolean values, where a value is True if all elements
17721738
are True within its respective group, False otherwise.
17731739
"""
1774-
return self._bool_agg("all", skipna)
1740+
return self._cython_agg_general(
1741+
"all",
1742+
alt=lambda x: Series(x).all(skipna=skipna),
1743+
skipna=skipna,
1744+
)
17751745

17761746
@final
17771747
@Substitution(name="groupby")
@@ -3702,116 +3672,6 @@ def cummax(
37023672
"cummax", numeric_only=numeric_only, skipna=skipna
37033673
)
37043674

3705-
@final
3706-
def _get_cythonized_result(
3707-
self,
3708-
base_func: Callable,
3709-
cython_dtype: np.dtype,
3710-
numeric_only: bool = False,
3711-
pre_processing=None,
3712-
post_processing=None,
3713-
how: str = "any_all",
3714-
**kwargs,
3715-
):
3716-
"""
3717-
Get result for Cythonized functions.
3718-
3719-
Parameters
3720-
----------
3721-
base_func : callable, Cythonized function to be called
3722-
cython_dtype : np.dtype
3723-
Type of the array that will be modified by the Cython call.
3724-
numeric_only : bool, default False
3725-
Whether only numeric datatypes should be computed
3726-
pre_processing : function, default None
3727-
Function to be applied to `values` prior to passing to Cython.
3728-
Function should return a tuple where the first element is the
3729-
values to be passed to Cython and the second element is an optional
3730-
type which the values should be converted to after being returned
3731-
by the Cython operation. This function is also responsible for
3732-
raising a TypeError if the values have an invalid type. Raises
3733-
if `needs_values` is False.
3734-
post_processing : function, default None
3735-
Function to be applied to result of Cython function. Should accept
3736-
an array of values as the first argument and type inferences as its
3737-
second argument, i.e. the signature should be
3738-
(ndarray, Type). If `needs_nullable=True`, a third argument should be
3739-
`nullable`, to allow for processing specific to nullable values.
3740-
how : str, default any_all
3741-
Determines if any/all cython interface or std interface is used.
3742-
**kwargs : dict
3743-
Extra arguments to be passed back to Cython funcs
3744-
3745-
Returns
3746-
-------
3747-
`Series` or `DataFrame` with filled values
3748-
"""
3749-
if post_processing and not callable(post_processing):
3750-
raise ValueError("'post_processing' must be a callable!")
3751-
if pre_processing and not callable(pre_processing):
3752-
raise ValueError("'pre_processing' must be a callable!")
3753-
3754-
grouper = self.grouper
3755-
3756-
ids, _, ngroups = grouper.group_info
3757-
3758-
base_func = partial(base_func, labels=ids)
3759-
3760-
def blk_func(values: ArrayLike) -> ArrayLike:
3761-
values = values.T
3762-
ncols = 1 if values.ndim == 1 else values.shape[1]
3763-
3764-
result: ArrayLike
3765-
result = np.zeros(ngroups * ncols, dtype=cython_dtype)
3766-
result = result.reshape((ngroups, ncols))
3767-
3768-
func = partial(base_func, out=result)
3769-
3770-
inferences = None
3771-
3772-
vals = values
3773-
if pre_processing:
3774-
vals, inferences = pre_processing(vals)
3775-
3776-
vals = vals.astype(cython_dtype, copy=False)
3777-
if vals.ndim == 1:
3778-
vals = vals.reshape((-1, 1))
3779-
func = partial(func, values=vals)
3780-
3781-
mask = isna(values).view(np.uint8)
3782-
if mask.ndim == 1:
3783-
mask = mask.reshape(-1, 1)
3784-
func = partial(func, mask=mask)
3785-
3786-
result_mask = None
3787-
if isinstance(values, BaseMaskedArray):
3788-
result_mask = np.zeros(result.shape, dtype=np.bool_)
3789-
3790-
func = partial(func, result_mask=result_mask)
3791-
3792-
# Call func to modify result in place
3793-
func(**kwargs)
3794-
3795-
if values.ndim == 1:
3796-
assert result.shape[1] == 1, result.shape
3797-
result = result[:, 0]
3798-
if result_mask is not None:
3799-
assert result_mask.shape[1] == 1, result_mask.shape
3800-
result_mask = result_mask[:, 0]
3801-
3802-
if post_processing:
3803-
result = post_processing(result, inferences, result_mask=result_mask)
3804-
3805-
return result.T
3806-
3807-
# Operate block-wise instead of column-by-column
3808-
mgr = self._get_data_to_aggregate(numeric_only=numeric_only, name=how)
3809-
3810-
res_mgr = mgr.grouped_reduce(blk_func)
3811-
3812-
out = self._wrap_agged_manager(res_mgr)
3813-
return self._wrap_aggregated_output(out)
3814-
38153675
@final
38163676
@Substitution(name="groupby")
38173677
def shift(

pandas/core/groupby/ops.py

+35-5
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,9 @@ class WrappedCythonOp:
115115

116116
# Functions for which we do _not_ attempt to cast the cython result
117117
# back to the original dtype.
118-
cast_blocklist = frozenset(["rank", "count", "size", "idxmin", "idxmax"])
118+
cast_blocklist = frozenset(
119+
["any", "all", "rank", "count", "size", "idxmin", "idxmax"]
120+
)
119121

120122
def __init__(self, kind: str, how: str, has_dropped_na: bool) -> None:
121123
self.kind = kind
@@ -124,6 +126,8 @@ def __init__(self, kind: str, how: str, has_dropped_na: bool) -> None:
124126

125127
_CYTHON_FUNCTIONS: dict[str, dict] = {
126128
"aggregate": {
129+
"any": functools.partial(libgroupby.group_any_all, val_test="any"),
130+
"all": functools.partial(libgroupby.group_any_all, val_test="all"),
127131
"sum": "group_sum",
128132
"prod": "group_prod",
129133
"min": "group_min",
@@ -253,7 +257,7 @@ def _disallow_invalid_ops(self, dtype: DtypeObj, is_numeric: bool = False):
253257
# don't go down a group-by-group path, since in the empty-groups
254258
# case that would fail to raise
255259
raise TypeError(f"Cannot perform {how} with non-ordered Categorical")
256-
if how not in ["rank"]:
260+
if how not in ["rank", "any", "all"]:
257261
# only "rank" is implemented in cython
258262
raise NotImplementedError(f"{dtype} dtype not supported")
259263

@@ -352,10 +356,13 @@ def _ea_wrap_cython_operation(
352356
)
353357

354358
elif isinstance(values, Categorical):
355-
assert self.how == "rank" # the only one implemented ATM
356-
assert values.ordered # checked earlier
359+
assert self.how in ["rank", "any", "all"]
357360
mask = values.isna()
358-
npvalues = values._ndarray
361+
if self.how == "rank":
362+
assert values.ordered # checked earlier
363+
npvalues = values._ndarray
364+
else:
365+
npvalues = values.astype(bool)
359366

360367
res_values = self._cython_op_ndim_compat(
361368
npvalues,
@@ -546,6 +553,19 @@ def _call_cython_op(
546553
if values.dtype == "float16":
547554
values = values.astype(np.float32)
548555

556+
if self.how in ["any", "all"]:
557+
if mask is None:
558+
mask = isna(values)
559+
if dtype == object:
560+
if kwargs["skipna"]:
561+
# GH#37501: don't raise on pd.NA when skipna=True
562+
if mask.any():
563+
# mask on original values computed separately
564+
values = values.copy()
565+
values[mask] = True
566+
values = values.astype(bool, copy=False).view(np.int8)
567+
is_numeric = True
568+
549569
values = values.T
550570
if mask is not None:
551571
mask = mask.T
@@ -584,6 +604,16 @@ def _call_cython_op(
584604
result_mask=result_mask,
585605
**kwargs,
586606
)
607+
elif self.how in ["any", "all"]:
608+
func(
609+
out=result,
610+
values=values,
611+
labels=comp_ids,
612+
mask=mask,
613+
result_mask=result_mask,
614+
**kwargs,
615+
)
616+
result = result.astype(bool, copy=False)
587617
else:
588618
raise NotImplementedError(f"{self.how} is not implemented")
589619
else:

0 commit comments

Comments
 (0)