Skip to content

BUG: nullable groupby result dtypes #46197

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions pandas/core/dtypes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1417,3 +1417,23 @@ def construct_array_type(cls) -> type_t[BaseMaskedArray]:
type
"""
raise NotImplementedError

@classmethod
def from_numpy_dtype(cls, dtype: np.dtype) -> BaseMaskedDtype:
"""
Construct the MaskedDtype corresponding to the given numpy dtype.
"""
if dtype.kind == "b":
from pandas.core.arrays.boolean import BooleanDtype

return BooleanDtype()
elif dtype.kind in ["i", "u"]:
from pandas.core.arrays.integer import INT_STR_TO_DTYPE
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a shame you have to do these imports...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I'm vaguely hoping to get rid of all of these and just have MaskedArray/MaskedDtype, but that's a long way off at best.


return INT_STR_TO_DTYPE[dtype.name]
elif dtype.kind == "f":
from pandas.core.arrays.floating import FLOAT_STR_TO_DTYPE

return FLOAT_STR_TO_DTYPE[dtype.name]
else:
raise NotImplementedError(dtype)
49 changes: 19 additions & 30 deletions pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
Iterator,
Sequence,
final,
overload,
)

import numpy as np
Expand Down Expand Up @@ -57,7 +56,6 @@
is_timedelta64_dtype,
needs_i8_conversion,
)
from pandas.core.dtypes.dtypes import ExtensionDtype
from pandas.core.dtypes.missing import (
isna,
maybe_fill,
Expand All @@ -70,14 +68,8 @@
TimedeltaArray,
)
from pandas.core.arrays.boolean import BooleanDtype
from pandas.core.arrays.floating import (
Float64Dtype,
FloatingDtype,
)
from pandas.core.arrays.integer import (
Int64Dtype,
IntegerDtype,
)
from pandas.core.arrays.floating import FloatingDtype
from pandas.core.arrays.integer import IntegerDtype
from pandas.core.arrays.masked import (
BaseMaskedArray,
BaseMaskedDtype,
Expand Down Expand Up @@ -277,41 +269,27 @@ def _get_out_dtype(self, dtype: np.dtype) -> np.dtype:
out_dtype = "object"
return np.dtype(out_dtype)

@overload
def _get_result_dtype(self, dtype: np.dtype) -> np.dtype:
... # pragma: no cover

@overload
def _get_result_dtype(self, dtype: ExtensionDtype) -> ExtensionDtype:
... # pragma: no cover

# TODO: general case implementation overridable by EAs.
def _get_result_dtype(self, dtype: DtypeObj) -> DtypeObj:
"""
Get the desired dtype of a result based on the
input dtype and how it was computed.

Parameters
----------
dtype : np.dtype or ExtensionDtype
Input dtype.
dtype : np.dtype

Returns
-------
np.dtype or ExtensionDtype
np.dtype
The desired dtype of the result.
"""
how = self.how

if how in ["add", "cumsum", "sum", "prod"]:
if dtype == np.dtype(bool):
return np.dtype(np.int64)
elif isinstance(dtype, (BooleanDtype, IntegerDtype)):
return Int64Dtype()
elif how in ["mean", "median", "var"]:
if isinstance(dtype, (BooleanDtype, IntegerDtype)):
return Float64Dtype()
elif is_float_dtype(dtype) or is_complex_dtype(dtype):
if is_float_dtype(dtype) or is_complex_dtype(dtype):
return dtype
elif is_numeric_dtype(dtype):
return np.dtype(np.float64)
Expand Down Expand Up @@ -390,8 +368,18 @@ def _reconstruct_ea_result(
Construct an ExtensionArray result from an ndarray result.
"""

if isinstance(values.dtype, (BaseMaskedDtype, StringDtype)):
dtype = self._get_result_dtype(values.dtype)
if isinstance(values.dtype, StringDtype):
dtype = values.dtype
cls = dtype.construct_array_type()
return cls._from_sequence(res_values, dtype=dtype)

elif isinstance(values.dtype, BaseMaskedDtype):
new_dtype = self._get_result_dtype(values.dtype.numpy_dtype)
# error: Incompatible types in assignment (expression has type
# "BaseMaskedDtype", variable has type "StringDtype")
dtype = BaseMaskedDtype.from_numpy_dtype( # type: ignore[assignment]
new_dtype
)
cls = dtype.construct_array_type()
return cls._from_sequence(res_values, dtype=dtype)

Expand Down Expand Up @@ -433,7 +421,8 @@ def _masked_ea_wrap_cython_operation(
**kwargs,
)

dtype = self._get_result_dtype(orig_values.dtype)
new_dtype = self._get_result_dtype(orig_values.dtype.numpy_dtype)
dtype = BaseMaskedDtype.from_numpy_dtype(new_dtype)
# TODO: avoid cast as res_values *should* already have the right
# dtype; last attempt ran into trouble on 32bit linux build
res_values = res_values.astype(dtype.type, copy=False)
Expand Down
8 changes: 7 additions & 1 deletion pandas/tests/groupby/aggregate/test_cython.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
import numpy as np
import pytest

from pandas.core.dtypes.common import is_float_dtype
from pandas.core.dtypes.common import (
is_float_dtype,
is_integer_dtype,
)

import pandas as pd
from pandas import (
Expand Down Expand Up @@ -369,6 +372,9 @@ def test_cython_agg_EA_known_dtypes(data, op_name, action, with_na):
# preserve float and integer dtypes; anything else (e.g. bool) uses Int64
if is_float_dtype(data.dtype):
expected_dtype = data.dtype
elif is_integer_dtype(data.dtype):
# match the numpy dtype we'd get with the non-nullable analogue
expected_dtype = data.dtype
else:
expected_dtype = pd.Int64Dtype()
elif action == "always_float":
Expand Down