Skip to content

REF: use WrappedCythonOp for GroupBy.std, sem #52053

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged: 4 commits were merged into the base branch on Mar 20, 2023.
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
1 change: 1 addition & 0 deletions pandas/_libs/groupby.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def group_var(
mask: np.ndarray | None = ...,
result_mask: np.ndarray | None = ...,
is_datetimelike: bool = ...,
name: str = ...,
) -> None: ...
def group_mean(
out: np.ndarray, # floating[:, ::1]
Expand Down
12 changes: 11 additions & 1 deletion pandas/_libs/groupby.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ from cython cimport (
Py_ssize_t,
floating,
)
from libc.math cimport sqrt
from libc.stdlib cimport (
free,
malloc,
Expand Down Expand Up @@ -822,6 +823,7 @@ def group_var(
const uint8_t[:, ::1] mask=None,
uint8_t[:, ::1] result_mask=None,
bint is_datetimelike=False,
str name="var",
) -> None:
cdef:
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
Expand All @@ -830,6 +832,8 @@ def group_var(
int64_t[:, ::1] nobs
Py_ssize_t len_values = len(values), len_labels = len(labels)
bint isna_entry, uses_mask = mask is not None
bint is_std = name == "std"
bint is_sem = name == "sem"

assert min_count == -1, "'min_count' only used in sum and prod"

Expand Down Expand Up @@ -879,7 +883,13 @@ def group_var(
else:
out[i, j] = NAN
else:
out[i, j] /= (ct - ddof)
if is_std:
out[i, j] = sqrt(out[i, j] / (ct - ddof))
elif is_sem:
out[i, j] = sqrt(out[i, j] / (ct - ddof) / ct)
else:
# just "var"
out[i, j] /= (ct - ddof)


@cython.wraparound(False)
Expand Down
77 changes: 14 additions & 63 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,8 @@ class providing the base-class of operations.
BaseMaskedArray,
BooleanArray,
Categorical,
DatetimeArray,
ExtensionArray,
FloatingArray,
TimedeltaArray,
)
from pandas.core.base import (
PandasObject,
Expand Down Expand Up @@ -1979,30 +1977,12 @@ def std(

return np.sqrt(self._numba_agg_general(sliding_var, engine_kwargs, ddof))
else:

def _preprocessing(values):
if isinstance(values, BaseMaskedArray):
return values._data, None
return values, None

def _postprocessing(vals, inference, result_mask=None) -> ArrayLike:
if result_mask is not None:
if result_mask.ndim == 2:
result_mask = result_mask[:, 0]
return FloatingArray(np.sqrt(vals), result_mask.view(np.bool_))
return np.sqrt(vals)

result = self._get_cythonized_result(
libgroupby.group_var,
cython_dtype=np.dtype(np.float64),
return self._cython_agg_general(
"std",
alt=lambda x: Series(x).std(ddof=ddof),
numeric_only=numeric_only,
needs_counts=True,
pre_processing=_preprocessing,
post_processing=_postprocessing,
ddof=ddof,
how="std",
)
return result

@final
@Substitution(name="groupby")
Expand Down Expand Up @@ -2231,18 +2211,12 @@ def sem(self, ddof: int = 1, numeric_only: bool = False):
f"{type(self).__name__}.sem called with "
f"numeric_only={numeric_only} and dtype {self.obj.dtype}"
)
result = self.std(ddof=ddof, numeric_only=numeric_only)

if result.ndim == 1:
result /= np.sqrt(self.count())
else:
cols = result.columns.difference(self.exclusions).unique()
counts = self.count()
result_ilocs = result.columns.get_indexer_for(cols)
count_ilocs = counts.columns.get_indexer_for(cols)

result.iloc[:, result_ilocs] /= np.sqrt(counts.iloc[:, count_ilocs])
return result
return self._cython_agg_general(
"sem",
alt=lambda x: Series(x).sem(ddof=ddof),
numeric_only=numeric_only,
ddof=ddof,
)

@final
@Substitution(name="groupby")
Expand Down Expand Up @@ -3720,7 +3694,6 @@ def _get_cythonized_result(
base_func: Callable,
cython_dtype: np.dtype,
numeric_only: bool = False,
needs_counts: bool = False,
pre_processing=None,
post_processing=None,
how: str = "any_all",
Expand All @@ -3736,8 +3709,6 @@ def _get_cythonized_result(
Type of the array that will be modified by the Cython call.
numeric_only : bool, default False
Whether only numeric datatypes should be computed
needs_counts : bool, default False
Whether the counts should be a part of the Cython call
pre_processing : function, default None
Function to be applied to `values` prior to passing to Cython.
Function should return a tuple where the first element is the
Expand Down Expand Up @@ -3784,14 +3755,7 @@ def blk_func(values: ArrayLike) -> ArrayLike:

inferences = None

if needs_counts:
counts = np.zeros(ngroups, dtype=np.int64)
func = partial(func, counts=counts)

is_datetimelike = values.dtype.kind in ["m", "M"]
vals = values
if is_datetimelike and how == "std":
vals = vals.view("i8")
if pre_processing:
vals, inferences = pre_processing(vals)

Expand All @@ -3800,11 +3764,10 @@ def blk_func(values: ArrayLike) -> ArrayLike:
vals = vals.reshape((-1, 1))
func = partial(func, values=vals)

if how != "std" or isinstance(values, BaseMaskedArray):
mask = isna(values).view(np.uint8)
if mask.ndim == 1:
mask = mask.reshape(-1, 1)
func = partial(func, mask=mask)
mask = isna(values).view(np.uint8)
if mask.ndim == 1:
mask = mask.reshape(-1, 1)
func = partial(func, mask=mask)

result_mask = None
if isinstance(values, BaseMaskedArray):
Expand All @@ -3813,10 +3776,7 @@ def blk_func(values: ArrayLike) -> ArrayLike:
func = partial(func, result_mask=result_mask)

# Call func to modify result in place
if how == "std":
func(**kwargs, is_datetimelike=is_datetimelike)
else:
func(**kwargs)
func(**kwargs)

if values.ndim == 1:
assert result.shape[1] == 1, result.shape
Expand All @@ -3828,15 +3788,6 @@ def blk_func(values: ArrayLike) -> ArrayLike:
if post_processing:
result = post_processing(result, inferences, result_mask=result_mask)

if how == "std" and is_datetimelike:
values = cast("DatetimeArray | TimedeltaArray", values)
unit = values.unit
with warnings.catch_warnings():
# suppress "RuntimeWarning: invalid value encountered in cast"
warnings.filterwarnings("ignore")
result = result.astype(np.int64, copy=False)
result = result.view(f"m8[{unit}]")

return result.T

# Operate block-wise instead of column-by-column
Expand Down
27 changes: 22 additions & 5 deletions pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def __init__(self, kind: str, how: str, has_dropped_na: bool) -> None:
self.how = how
self.has_dropped_na = has_dropped_na

_CYTHON_FUNCTIONS = {
_CYTHON_FUNCTIONS: dict[str, dict] = {
"aggregate": {
"sum": "group_sum",
"prod": "group_prod",
Expand All @@ -131,6 +131,8 @@ def __init__(self, kind: str, how: str, has_dropped_na: bool) -> None:
"mean": "group_mean",
"median": "group_median_float64",
"var": "group_var",
"std": functools.partial(libgroupby.group_var, name="std"),
"sem": functools.partial(libgroupby.group_var, name="sem"),
"first": "group_nth",
"last": "group_last",
"ohlc": "group_ohlc",
Expand Down Expand Up @@ -158,7 +160,10 @@ def _get_cython_function(

# see if there is a fused-type version of function
# only valid for numeric
f = getattr(libgroupby, ftype)
if callable(ftype):
f = ftype
else:
f = getattr(libgroupby, ftype)
if is_numeric:
return f
elif dtype == np.dtype(object):
Expand All @@ -168,6 +173,9 @@ def _get_cython_function(
f"function is not implemented for this dtype: "
f"[how->{how},dtype->{dtype_str}]"
)
elif how in ["std", "sem"]:
# We have a partial object that does not have __signatures__
return f
if "object" not in f.__signatures__:
# raise NotImplementedError here rather than TypeError later
raise NotImplementedError(
Expand Down Expand Up @@ -196,7 +204,7 @@ def _get_cython_vals(self, values: np.ndarray) -> np.ndarray:
"""
how = self.how

if how == "median":
if how in ["median", "std", "sem"]:
# median only has a float64 implementation
# We should only get here with is_numeric, as non-numeric cases
# should raise in _get_cython_function
Expand Down Expand Up @@ -314,7 +322,7 @@ def _get_result_dtype(self, dtype: np.dtype) -> np.dtype:
if how in ["sum", "cumsum", "sum", "prod", "cumprod"]:
if dtype == np.dtype(bool):
return np.dtype(np.int64)
elif how in ["mean", "median", "var"]:
elif how in ["mean", "median", "var", "std", "sem"]:
if is_float_dtype(dtype) or is_complex_dtype(dtype):
return dtype
elif is_numeric_dtype(dtype):
Expand Down Expand Up @@ -413,6 +421,13 @@ def _reconstruct_ea_result(
elif isinstance(values, (DatetimeArray, TimedeltaArray, PeriodArray)):
# In to_cython_values we took a view as M8[ns]
assert res_values.dtype == "M8[ns]"
if self.how in ["std", "sem"]:
if isinstance(values, PeriodArray):
raise TypeError("'std' and 'sem' are not valid for PeriodDtype")
new_dtype = f"m8[{values.unit}]"
res_values = res_values.view(new_dtype)
return TimedeltaArray(res_values)

res_values = res_values.view(values._ndarray.dtype)
return values._from_backing_data(res_values)

Expand Down Expand Up @@ -556,7 +571,9 @@ def _call_cython_op(
result_mask=result_mask,
is_datetimelike=is_datetimelike,
)
elif self.how in ["var", "ohlc", "prod", "median"]:
elif self.how in ["sem", "std", "var", "ohlc", "prod", "median"]:
if self.how in ["std", "sem"]:
kwargs["is_datetimelike"] = is_datetimelike
func(
result,
counts,
Expand Down
20 changes: 16 additions & 4 deletions pandas/tests/groupby/test_raises.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,14 +399,20 @@ def test_groupby_raises_category(
"prod": (TypeError, "category type does not support prod operations"),
"quantile": (TypeError, "No matching signature found"),
"rank": (None, ""),
"sem": (ValueError, "Cannot cast object dtype to float64"),
"sem": (
TypeError,
"'Categorical' with dtype category does not support reduction 'sem'",
),
"shift": (None, ""),
"size": (None, ""),
"skew": (
TypeError,
"'Categorical' with dtype category does not support reduction 'skew'",
),
"std": (ValueError, "Cannot cast object dtype to float64"),
"std": (
TypeError,
"'Categorical' with dtype category does not support reduction 'std'",
),
"sum": (TypeError, "category type does not support sum operations"),
"var": (
TypeError,
Expand Down Expand Up @@ -594,14 +600,20 @@ def test_groupby_raises_category_on_category(
"prod": (TypeError, "category type does not support prod operations"),
"quantile": (TypeError, ""),
"rank": (None, ""),
"sem": (ValueError, "Cannot cast object dtype to float64"),
"sem": (
TypeError,
"'Categorical' with dtype category does not support reduction 'sem'",
),
"shift": (None, ""),
"size": (None, ""),
"skew": (
TypeError,
"'Categorical' with dtype category does not support reduction 'skew'",
),
"std": (ValueError, "Cannot cast object dtype to float64"),
"std": (
TypeError,
"'Categorical' with dtype category does not support reduction 'std'",
),
"sum": (TypeError, "category type does not support sum operations"),
"var": (
TypeError,
Expand Down