Skip to content

Commit 33a3fb1

Browse files
authored
REF: use WrappedCythonOp for GroupBy.std, sem (#52053)
* REF: use WrappedCythonOp for GroupBy.std, sem * mypy fixup
1 parent 8023225 commit 33a3fb1

File tree

5 files changed

+64
-73
lines changed

5 files changed

+64
-73
lines changed

pandas/_libs/groupby.pyi

+1
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def group_var(
8686
mask: np.ndarray | None = ...,
8787
result_mask: np.ndarray | None = ...,
8888
is_datetimelike: bool = ...,
89+
name: str = ...,
8990
) -> None: ...
9091
def group_mean(
9192
out: np.ndarray, # floating[:, ::1]

pandas/_libs/groupby.pyx

+11-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ from cython cimport (
33
Py_ssize_t,
44
floating,
55
)
6+
from libc.math cimport sqrt
67
from libc.stdlib cimport (
78
free,
89
malloc,
@@ -822,6 +823,7 @@ def group_var(
822823
const uint8_t[:, ::1] mask=None,
823824
uint8_t[:, ::1] result_mask=None,
824825
bint is_datetimelike=False,
826+
str name="var",
825827
) -> None:
826828
cdef:
827829
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
@@ -830,6 +832,8 @@ def group_var(
830832
int64_t[:, ::1] nobs
831833
Py_ssize_t len_values = len(values), len_labels = len(labels)
832834
bint isna_entry, uses_mask = mask is not None
835+
bint is_std = name == "std"
836+
bint is_sem = name == "sem"
833837

834838
assert min_count == -1, "'min_count' only used in sum and prod"
835839

@@ -879,7 +883,13 @@ def group_var(
879883
else:
880884
out[i, j] = NAN
881885
else:
882-
out[i, j] /= (ct - ddof)
886+
if is_std:
887+
out[i, j] = sqrt(out[i, j] / (ct - ddof))
888+
elif is_sem:
889+
out[i, j] = sqrt(out[i, j] / (ct - ddof) / ct)
890+
else:
891+
# just "var"
892+
out[i, j] /= (ct - ddof)
883893

884894

885895
@cython.wraparound(False)

pandas/core/groupby/groupby.py

+14-63
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,8 @@ class providing the base-class of operations.
9797
BaseMaskedArray,
9898
BooleanArray,
9999
Categorical,
100-
DatetimeArray,
101100
ExtensionArray,
102101
FloatingArray,
103-
TimedeltaArray,
104102
)
105103
from pandas.core.base import (
106104
PandasObject,
@@ -1993,30 +1991,12 @@ def std(
19931991

19941992
return np.sqrt(self._numba_agg_general(sliding_var, engine_kwargs, ddof))
19951993
else:
1996-
1997-
def _preprocessing(values):
1998-
if isinstance(values, BaseMaskedArray):
1999-
return values._data, None
2000-
return values, None
2001-
2002-
def _postprocessing(vals, inference, result_mask=None) -> ArrayLike:
2003-
if result_mask is not None:
2004-
if result_mask.ndim == 2:
2005-
result_mask = result_mask[:, 0]
2006-
return FloatingArray(np.sqrt(vals), result_mask.view(np.bool_))
2007-
return np.sqrt(vals)
2008-
2009-
result = self._get_cythonized_result(
2010-
libgroupby.group_var,
2011-
cython_dtype=np.dtype(np.float64),
1994+
return self._cython_agg_general(
1995+
"std",
1996+
alt=lambda x: Series(x).std(ddof=ddof),
20121997
numeric_only=numeric_only,
2013-
needs_counts=True,
2014-
pre_processing=_preprocessing,
2015-
post_processing=_postprocessing,
20161998
ddof=ddof,
2017-
how="std",
20181999
)
2019-
return result
20202000

20212001
@final
20222002
@Substitution(name="groupby")
@@ -2245,18 +2225,12 @@ def sem(self, ddof: int = 1, numeric_only: bool = False):
22452225
f"{type(self).__name__}.sem called with "
22462226
f"numeric_only={numeric_only} and dtype {self.obj.dtype}"
22472227
)
2248-
result = self.std(ddof=ddof, numeric_only=numeric_only)
2249-
2250-
if result.ndim == 1:
2251-
result /= np.sqrt(self.count())
2252-
else:
2253-
cols = result.columns.difference(self.exclusions).unique()
2254-
counts = self.count()
2255-
result_ilocs = result.columns.get_indexer_for(cols)
2256-
count_ilocs = counts.columns.get_indexer_for(cols)
2257-
2258-
result.iloc[:, result_ilocs] /= np.sqrt(counts.iloc[:, count_ilocs])
2259-
return result
2228+
return self._cython_agg_general(
2229+
"sem",
2230+
alt=lambda x: Series(x).sem(ddof=ddof),
2231+
numeric_only=numeric_only,
2232+
ddof=ddof,
2233+
)
22602234

22612235
@final
22622236
@Substitution(name="groupby")
@@ -3734,7 +3708,6 @@ def _get_cythonized_result(
37343708
base_func: Callable,
37353709
cython_dtype: np.dtype,
37363710
numeric_only: bool = False,
3737-
needs_counts: bool = False,
37383711
pre_processing=None,
37393712
post_processing=None,
37403713
how: str = "any_all",
@@ -3750,8 +3723,6 @@ def _get_cythonized_result(
37503723
Type of the array that will be modified by the Cython call.
37513724
numeric_only : bool, default False
37523725
Whether only numeric datatypes should be computed
3753-
needs_counts : bool, default False
3754-
Whether the counts should be a part of the Cython call
37553726
pre_processing : function, default None
37563727
Function to be applied to `values` prior to passing to Cython.
37573728
Function should return a tuple where the first element is the
@@ -3798,14 +3769,7 @@ def blk_func(values: ArrayLike) -> ArrayLike:
37983769

37993770
inferences = None
38003771

3801-
if needs_counts:
3802-
counts = np.zeros(ngroups, dtype=np.int64)
3803-
func = partial(func, counts=counts)
3804-
3805-
is_datetimelike = values.dtype.kind in ["m", "M"]
38063772
vals = values
3807-
if is_datetimelike and how == "std":
3808-
vals = vals.view("i8")
38093773
if pre_processing:
38103774
vals, inferences = pre_processing(vals)
38113775

@@ -3814,11 +3778,10 @@ def blk_func(values: ArrayLike) -> ArrayLike:
38143778
vals = vals.reshape((-1, 1))
38153779
func = partial(func, values=vals)
38163780

3817-
if how != "std" or isinstance(values, BaseMaskedArray):
3818-
mask = isna(values).view(np.uint8)
3819-
if mask.ndim == 1:
3820-
mask = mask.reshape(-1, 1)
3821-
func = partial(func, mask=mask)
3781+
mask = isna(values).view(np.uint8)
3782+
if mask.ndim == 1:
3783+
mask = mask.reshape(-1, 1)
3784+
func = partial(func, mask=mask)
38223785

38233786
result_mask = None
38243787
if isinstance(values, BaseMaskedArray):
@@ -3827,10 +3790,7 @@ def blk_func(values: ArrayLike) -> ArrayLike:
38273790
func = partial(func, result_mask=result_mask)
38283791

38293792
# Call func to modify result in place
3830-
if how == "std":
3831-
func(**kwargs, is_datetimelike=is_datetimelike)
3832-
else:
3833-
func(**kwargs)
3793+
func(**kwargs)
38343794

38353795
if values.ndim == 1:
38363796
assert result.shape[1] == 1, result.shape
@@ -3842,15 +3802,6 @@ def blk_func(values: ArrayLike) -> ArrayLike:
38423802
if post_processing:
38433803
result = post_processing(result, inferences, result_mask=result_mask)
38443804

3845-
if how == "std" and is_datetimelike:
3846-
values = cast("DatetimeArray | TimedeltaArray", values)
3847-
unit = values.unit
3848-
with warnings.catch_warnings():
3849-
# suppress "RuntimeWarning: invalid value encountered in cast"
3850-
warnings.filterwarnings("ignore")
3851-
result = result.astype(np.int64, copy=False)
3852-
result = result.view(f"m8[{unit}]")
3853-
38543805
return result.T
38553806

38563807
# Operate block-wise instead of column-by-column

pandas/core/groupby/ops.py

+22-5
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def __init__(self, kind: str, how: str, has_dropped_na: bool) -> None:
122122
self.how = how
123123
self.has_dropped_na = has_dropped_na
124124

125-
_CYTHON_FUNCTIONS = {
125+
_CYTHON_FUNCTIONS: dict[str, dict] = {
126126
"aggregate": {
127127
"sum": "group_sum",
128128
"prod": "group_prod",
@@ -131,6 +131,8 @@ def __init__(self, kind: str, how: str, has_dropped_na: bool) -> None:
131131
"mean": "group_mean",
132132
"median": "group_median_float64",
133133
"var": "group_var",
134+
"std": functools.partial(libgroupby.group_var, name="std"),
135+
"sem": functools.partial(libgroupby.group_var, name="sem"),
134136
"first": "group_nth",
135137
"last": "group_last",
136138
"ohlc": "group_ohlc",
@@ -158,7 +160,10 @@ def _get_cython_function(
158160

159161
# see if there is a fused-type version of function
160162
# only valid for numeric
161-
f = getattr(libgroupby, ftype)
163+
if callable(ftype):
164+
f = ftype
165+
else:
166+
f = getattr(libgroupby, ftype)
162167
if is_numeric:
163168
return f
164169
elif dtype == np.dtype(object):
@@ -168,6 +173,9 @@ def _get_cython_function(
168173
f"function is not implemented for this dtype: "
169174
f"[how->{how},dtype->{dtype_str}]"
170175
)
176+
elif how in ["std", "sem"]:
177+
# We have a partial object that does not have __signatures__
178+
return f
171179
if "object" not in f.__signatures__:
172180
# raise NotImplementedError here rather than TypeError later
173181
raise NotImplementedError(
@@ -196,7 +204,7 @@ def _get_cython_vals(self, values: np.ndarray) -> np.ndarray:
196204
"""
197205
how = self.how
198206

199-
if how == "median":
207+
if how in ["median", "std", "sem"]:
200208
# median only has a float64 implementation
201209
# We should only get here with is_numeric, as non-numeric cases
202210
# should raise in _get_cython_function
@@ -314,7 +322,7 @@ def _get_result_dtype(self, dtype: np.dtype) -> np.dtype:
314322
if how in ["sum", "cumsum", "sum", "prod", "cumprod"]:
315323
if dtype == np.dtype(bool):
316324
return np.dtype(np.int64)
317-
elif how in ["mean", "median", "var"]:
325+
elif how in ["mean", "median", "var", "std", "sem"]:
318326
if is_float_dtype(dtype) or is_complex_dtype(dtype):
319327
return dtype
320328
elif is_numeric_dtype(dtype):
@@ -413,6 +421,13 @@ def _reconstruct_ea_result(
413421
elif isinstance(values, (DatetimeArray, TimedeltaArray, PeriodArray)):
414422
# In to_cython_values we took a view as M8[ns]
415423
assert res_values.dtype == "M8[ns]"
424+
if self.how in ["std", "sem"]:
425+
if isinstance(values, PeriodArray):
426+
raise TypeError("'std' and 'sem' are not valid for PeriodDtype")
427+
new_dtype = f"m8[{values.unit}]"
428+
res_values = res_values.view(new_dtype)
429+
return TimedeltaArray(res_values)
430+
416431
res_values = res_values.view(values._ndarray.dtype)
417432
return values._from_backing_data(res_values)
418433

@@ -556,7 +571,9 @@ def _call_cython_op(
556571
result_mask=result_mask,
557572
is_datetimelike=is_datetimelike,
558573
)
559-
elif self.how in ["var", "ohlc", "prod", "median"]:
574+
elif self.how in ["sem", "std", "var", "ohlc", "prod", "median"]:
575+
if self.how in ["std", "sem"]:
576+
kwargs["is_datetimelike"] = is_datetimelike
560577
func(
561578
result,
562579
counts,

pandas/tests/groupby/test_raises.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -399,14 +399,20 @@ def test_groupby_raises_category(
399399
"prod": (TypeError, "category type does not support prod operations"),
400400
"quantile": (TypeError, "No matching signature found"),
401401
"rank": (None, ""),
402-
"sem": (ValueError, "Cannot cast object dtype to float64"),
402+
"sem": (
403+
TypeError,
404+
"'Categorical' with dtype category does not support reduction 'sem'",
405+
),
403406
"shift": (None, ""),
404407
"size": (None, ""),
405408
"skew": (
406409
TypeError,
407410
"'Categorical' with dtype category does not support reduction 'skew'",
408411
),
409-
"std": (ValueError, "Cannot cast object dtype to float64"),
412+
"std": (
413+
TypeError,
414+
"'Categorical' with dtype category does not support reduction 'std'",
415+
),
410416
"sum": (TypeError, "category type does not support sum operations"),
411417
"var": (
412418
TypeError,
@@ -594,14 +600,20 @@ def test_groupby_raises_category_on_category(
594600
"prod": (TypeError, "category type does not support prod operations"),
595601
"quantile": (TypeError, ""),
596602
"rank": (None, ""),
597-
"sem": (ValueError, "Cannot cast object dtype to float64"),
603+
"sem": (
604+
TypeError,
605+
"'Categorical' with dtype category does not support reduction 'sem'",
606+
),
598607
"shift": (None, ""),
599608
"size": (None, ""),
600609
"skew": (
601610
TypeError,
602611
"'Categorical' with dtype category does not support reduction 'skew'",
603612
),
604-
"std": (ValueError, "Cannot cast object dtype to float64"),
613+
"std": (
614+
TypeError,
615+
"'Categorical' with dtype category does not support reduction 'std'",
616+
),
605617
"sum": (TypeError, "category type does not support sum operations"),
606618
"var": (
607619
TypeError,

0 commit comments

Comments
 (0)