Skip to content

Commit 548444b

Browse files
authored
ENH: Preserve ea dtype in groupby.std (#50375)
* ENH: Preserve ea dtype in groupby.std * Add gh ref * Add typing * Add docstring
1 parent 3ea04c3 commit 548444b

File tree

3 files changed

+41
-17
lines changed

3 files changed

+41
-17
lines changed

doc/source/whatsnew/v2.0.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ be set to ``"pyarrow"`` to return pyarrow-backed, nullable :class:`ArrowDtype` (
8282
Other enhancements
8383
^^^^^^^^^^^^^^^^^^
8484
- :func:`read_sas` now supports using ``encoding='infer'`` to correctly read and use the encoding specified by the sas file. (:issue:`48048`)
85-
- :meth:`.DataFrameGroupBy.quantile` and :meth:`.SeriesGroupBy.quantile` now preserve nullable dtypes instead of casting to numpy dtypes (:issue:`37493`)
85+
- :meth:`.DataFrameGroupBy.quantile`, :meth:`.SeriesGroupBy.quantile` and :meth:`.DataFrameGroupBy.std` now preserve nullable dtypes instead of casting to numpy dtypes (:issue:`37493`)
8686
- :meth:`Series.add_suffix`, :meth:`DataFrame.add_suffix`, :meth:`Series.add_prefix` and :meth:`DataFrame.add_prefix` support an ``axis`` argument. If ``axis`` is set, the default behaviour of which axis to consider can be overwritten (:issue:`47819`)
8787
- :func:`assert_frame_equal` now shows the first element where the DataFrames differ, analogously to ``pytest``'s output (:issue:`47910`)
8888
- Added ``index`` parameter to :meth:`DataFrame.to_dict` (:issue:`46398`)

pandas/core/groupby/groupby.py

+25-16
Original file line numberDiff line numberDiff line change
@@ -1806,8 +1806,6 @@ def result_to_bool(
18061806
libgroupby.group_any_all,
18071807
numeric_only=False,
18081808
cython_dtype=np.dtype(np.int8),
1809-
needs_mask=True,
1810-
needs_nullable=True,
18111809
pre_processing=objs_to_bool,
18121810
post_processing=result_to_bool,
18131811
val_test=val_test,
@@ -2084,13 +2082,24 @@ def std(
20842082
f"{type(self).__name__}.std called with "
20852083
f"numeric_only={numeric_only} and dtype {self.obj.dtype}"
20862084
)
2085+
2086+
def _postprocessing(
2087+
vals, inference, nullable: bool = False, mask=None
2088+
) -> ArrayLike:
2089+
if nullable:
2090+
if mask.ndim == 2:
2091+
mask = mask[:, 0]
2092+
return FloatingArray(np.sqrt(vals), mask.view(np.bool_))
2093+
return np.sqrt(vals)
2094+
20872095
result = self._get_cythonized_result(
20882096
libgroupby.group_var,
20892097
cython_dtype=np.dtype(np.float64),
20902098
numeric_only=numeric_only,
20912099
needs_counts=True,
2092-
post_processing=lambda vals, inference: np.sqrt(vals),
2100+
post_processing=_postprocessing,
20932101
ddof=ddof,
2102+
how="std",
20942103
)
20952104
return result
20962105

@@ -3498,10 +3507,9 @@ def _get_cythonized_result(
34983507
cython_dtype: np.dtype,
34993508
numeric_only: bool = False,
35003509
needs_counts: bool = False,
3501-
needs_nullable: bool = False,
3502-
needs_mask: bool = False,
35033510
pre_processing=None,
35043511
post_processing=None,
3512+
how: str = "any_all",
35053513
**kwargs,
35063514
):
35073515
"""
@@ -3516,12 +3524,6 @@ def _get_cythonized_result(
35163524
Whether only numeric datatypes should be computed
35173525
needs_counts : bool, default False
35183526
Whether the counts should be a part of the Cython call
3519-
needs_mask : bool, default False
3520-
Whether boolean mask needs to be part of the Cython call
3521-
signature
3522-
needs_nullable : bool, default False
3523-
Whether a bool specifying if the input is nullable is part
3524-
of the Cython call signature
35253527
pre_processing : function, default None
35263528
Function to be applied to `values` prior to passing to Cython.
35273529
Function should return a tuple where the first element is the
@@ -3536,6 +3538,8 @@ def _get_cythonized_result(
35363538
second argument, i.e. the signature should be
35373539
(ndarray, Type). If `needs_nullable=True`, a third argument should be
35383540
`nullable`, to allow for processing specific to nullable values.
3541+
how : str, default any_all
3542+
Determines if any/all cython interface or std interface is used.
35393543
**kwargs : dict
35403544
Extra arguments to be passed back to Cython funcs
35413545
@@ -3579,26 +3583,31 @@ def blk_func(values: ArrayLike) -> ArrayLike:
35793583
vals = vals.reshape((-1, 1))
35803584
func = partial(func, values=vals)
35813585

3582-
if needs_mask:
3586+
if how != "std" or isinstance(values, BaseMaskedArray):
35833587
mask = isna(values).view(np.uint8)
35843588
if mask.ndim == 1:
35853589
mask = mask.reshape(-1, 1)
35863590
func = partial(func, mask=mask)
35873591

3588-
if needs_nullable:
3592+
if how != "std":
35893593
is_nullable = isinstance(values, BaseMaskedArray)
35903594
func = partial(func, nullable=is_nullable)
35913595

3596+
else:
3597+
result_mask = np.zeros(result.shape, dtype=np.bool_)
3598+
func = partial(func, result_mask=result_mask)
3599+
35923600
func(**kwargs) # Call func to modify indexer values in place
35933601

35943602
if values.ndim == 1:
35953603
assert result.shape[1] == 1, result.shape
35963604
result = result[:, 0]
35973605

35983606
if post_processing:
3599-
pp_kwargs = {}
3600-
if needs_nullable:
3601-
pp_kwargs["nullable"] = isinstance(values, BaseMaskedArray)
3607+
pp_kwargs: dict[str, bool | np.ndarray] = {}
3608+
pp_kwargs["nullable"] = isinstance(values, BaseMaskedArray)
3609+
if how == "std":
3610+
pp_kwargs["mask"] = result_mask
36023611

36033612
result = post_processing(result, inferences, **pp_kwargs)
36043613

pandas/tests/groupby/aggregate/test_aggregate.py

+15
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,21 @@ def test_aggregate_str_func(tsframe, groupbyfunc):
210210
tm.assert_frame_equal(result, expected)
211211

212212

213+
def test_std_masked_dtype(any_numeric_ea_dtype):
214+
# GH#35516
215+
df = DataFrame(
216+
{
217+
"a": [2, 1, 1, 1, 2, 2, 1],
218+
"b": Series([pd.NA, 1, 2, 1, 1, 1, 2], dtype="Float64"),
219+
}
220+
)
221+
result = df.groupby("a").std()
222+
expected = DataFrame(
223+
{"b": [0.57735, 0]}, index=Index([1, 2], name="a"), dtype="Float64"
224+
)
225+
tm.assert_frame_equal(result, expected)
226+
227+
213228
def test_agg_str_with_kwarg_axis_1_raises(df, reduction_func):
214229
gb = df.groupby(level=0)
215230
if reduction_func in ("idxmax", "idxmin"):

0 commit comments

Comments
 (0)