Skip to content

ENH: Preserve ea dtype in groupby.std #50375

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v2.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ be set to ``"pyarrow"`` to return pyarrow-backed, nullable :class:`ArrowDtype` (
Other enhancements
^^^^^^^^^^^^^^^^^^
- :func:`read_sas` now supports using ``encoding='infer'`` to correctly read and use the encoding specified by the sas file. (:issue:`48048`)
- :meth:`.DataFrameGroupBy.quantile` and :meth:`.SeriesGroupBy.quantile` now preserve nullable dtypes instead of casting to numpy dtypes (:issue:`37493`)
- :meth:`.DataFrameGroupBy.quantile`, :meth:`.SeriesGroupBy.quantile` and :meth:`.DataFrameGroupBy.std` now preserve nullable dtypes instead of casting to numpy dtypes (:issue:`37493`)
- :meth:`Series.add_suffix`, :meth:`DataFrame.add_suffix`, :meth:`Series.add_prefix` and :meth:`DataFrame.add_prefix` support an ``axis`` argument. If ``axis`` is set, the default behaviour of which axis to consider can be overwritten (:issue:`47819`)
- :func:`assert_frame_equal` now shows the first element where the DataFrames differ, analogously to ``pytest``'s output (:issue:`47910`)
- Added ``index`` parameter to :meth:`DataFrame.to_dict` (:issue:`46398`)
Expand Down
41 changes: 25 additions & 16 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1807,8 +1807,6 @@ def result_to_bool(
libgroupby.group_any_all,
numeric_only=False,
cython_dtype=np.dtype(np.int8),
needs_mask=True,
needs_nullable=True,
pre_processing=objs_to_bool,
post_processing=result_to_bool,
val_test=val_test,
Expand Down Expand Up @@ -2085,13 +2083,24 @@ def std(
f"{type(self).__name__}.std called with "
f"numeric_only={numeric_only} and dtype {self.obj.dtype}"
)

def _postprocessing(
vals, inference, nullable: bool = False, mask=None
) -> ArrayLike:
if nullable:
if mask.ndim == 2:
mask = mask[:, 0]
return FloatingArray(np.sqrt(vals), mask.view(np.bool_))
return np.sqrt(vals)

result = self._get_cythonized_result(
libgroupby.group_var,
cython_dtype=np.dtype(np.float64),
numeric_only=numeric_only,
needs_counts=True,
post_processing=lambda vals, inference: np.sqrt(vals),
post_processing=_postprocessing,
ddof=ddof,
how="std",
)
return result

Expand Down Expand Up @@ -3501,10 +3510,9 @@ def _get_cythonized_result(
cython_dtype: np.dtype,
numeric_only: bool = False,
needs_counts: bool = False,
needs_nullable: bool = False,
needs_mask: bool = False,
pre_processing=None,
post_processing=None,
how: str = "any_all",
**kwargs,
):
"""
Expand All @@ -3519,12 +3527,6 @@ def _get_cythonized_result(
Whether only numeric datatypes should be computed
needs_counts : bool, default False
Whether the counts should be a part of the Cython call
needs_mask : bool, default False
Whether boolean mask needs to be part of the Cython call
signature
needs_nullable : bool, default False
Whether a bool specifying if the input is nullable is part
of the Cython call signature
pre_processing : function, default None
Function to be applied to `values` prior to passing to Cython.
Function should return a tuple where the first element is the
Expand All @@ -3539,6 +3541,8 @@ def _get_cythonized_result(
second argument, i.e. the signature should be
(ndarray, Type). A `nullable` keyword argument is additionally passed, to
allow for processing specific to nullable values, and when `how="std"` a
`mask` keyword argument carrying the result mask is passed as well.
how : str, default any_all
Determines if any/all cython interface or std interface is used.
**kwargs : dict
Extra arguments to be passed back to Cython funcs

Expand Down Expand Up @@ -3582,26 +3586,31 @@ def blk_func(values: ArrayLike) -> ArrayLike:
vals = vals.reshape((-1, 1))
func = partial(func, values=vals)

if needs_mask:
if how != "std" or isinstance(values, BaseMaskedArray):
mask = isna(values).view(np.uint8)
if mask.ndim == 1:
mask = mask.reshape(-1, 1)
func = partial(func, mask=mask)

if needs_nullable:
if how != "std":
is_nullable = isinstance(values, BaseMaskedArray)
func = partial(func, nullable=is_nullable)

else:
result_mask = np.zeros(result.shape, dtype=np.bool_)
func = partial(func, result_mask=result_mask)

func(**kwargs) # Call func to modify indexer values in place

if values.ndim == 1:
assert result.shape[1] == 1, result.shape
result = result[:, 0]

if post_processing:
pp_kwargs = {}
if needs_nullable:
pp_kwargs["nullable"] = isinstance(values, BaseMaskedArray)
pp_kwargs: dict[str, bool | np.ndarray] = {}
pp_kwargs["nullable"] = isinstance(values, BaseMaskedArray)
if how == "std":
pp_kwargs["mask"] = result_mask

result = post_processing(result, inferences, **pp_kwargs)

Expand Down
15 changes: 15 additions & 0 deletions pandas/tests/groupby/aggregate/test_aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,21 @@ def test_aggregate_str_func(tsframe, groupbyfunc):
tm.assert_frame_equal(result, expected)


def test_std_masked_dtype(any_numeric_ea_dtype):
    # GH#35516
    # NOTE(review): the ``any_numeric_ea_dtype`` fixture is requested but not
    # referenced below; the value column is always built as "Float64".
    keys = [2, 1, 1, 1, 2, 2, 1]
    values = Series([pd.NA, 1, 2, 1, 1, 1, 2], dtype="Float64")
    frame = DataFrame({"a": keys, "b": values})

    result = frame.groupby("a").std()

    # The masked dtype must be kept on the aggregated column.
    expected_index = Index([1, 2], name="a")
    expected = DataFrame({"b": [0.57735, 0]}, index=expected_index, dtype="Float64")
    tm.assert_frame_equal(result, expected)


def test_agg_str_with_kwarg_axis_1_raises(df, reduction_func):
gb = df.groupby(level=0)
if reduction_func in ("idxmax", "idxmin"):
Expand Down