diff --git a/doc/source/whatsnew/v1.3.0.rst b/doc/source/whatsnew/v1.3.0.rst
index 1a11fffbf6b4e..2487b5fca591b 100644
--- a/doc/source/whatsnew/v1.3.0.rst
+++ b/doc/source/whatsnew/v1.3.0.rst
@@ -832,6 +832,7 @@ Groupby/resample/rolling
 - Bug in :class:`core.window.RollingGroupby` where ``as_index=False`` argument in ``groupby`` was ignored (:issue:`39433`)
 - Bug in :meth:`.GroupBy.any` and :meth:`.GroupBy.all` raising ``ValueError`` when using with nullable type columns holding ``NA`` even with ``skipna=True`` (:issue:`40585`)
 - Bug in :meth:`GroupBy.cummin` and :meth:`GroupBy.cummax` incorrectly rounding integer values near the ``int64`` implementations bounds (:issue:`40767`)
+- Bug in :meth:`.GroupBy.rank` with nullable dtypes incorrectly raising ``TypeError`` (:issue:`41010`)
 
 Reshaping
 ^^^^^^^^^
diff --git a/pandas/core/groupby/base.py b/pandas/core/groupby/base.py
index 50248d5af8883..6c1d6847a0bde 100644
--- a/pandas/core/groupby/base.py
+++ b/pandas/core/groupby/base.py
@@ -122,8 +122,6 @@ def _gotitem(self, key, ndim, subset=None):
 # require postprocessing of the result by transform.
 cythonized_kernels = frozenset(["cumprod", "cumsum", "shift", "cummin", "cummax"])
 
-cython_cast_blocklist = frozenset(["rank", "count", "size", "idxmin", "idxmax"])
-
 # List of aggregation/reduction functions.
 # These map each group to a single numeric value
 reduction_kernels = frozenset(
diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py
index d9bf1adf74a5e..6eddf8e9e8773 100644
--- a/pandas/core/groupby/ops.py
+++ b/pandas/core/groupby/ops.py
@@ -58,7 +58,6 @@
     is_timedelta64_dtype,
     needs_i8_conversion,
 )
-from pandas.core.dtypes.dtypes import ExtensionDtype
 from pandas.core.dtypes.generic import ABCCategoricalIndex
 from pandas.core.dtypes.missing import (
     isna,
@@ -95,6 +94,10 @@ class WrappedCythonOp:
     Dispatch logic for functions defined in _libs.groupby
     """
 
+    # Functions for which we do _not_ attempt to cast the cython result
+    # back to the original dtype.
+    cast_blocklist = frozenset(["rank", "count", "size", "idxmin", "idxmax"])
+
     def __init__(self, kind: str, how: str):
         self.kind = kind
         self.how = how
@@ -564,11 +567,13 @@ def _ea_wrap_cython_operation(
         if is_datetime64tz_dtype(values.dtype) or is_period_dtype(values.dtype):
             # All of the functions implemented here are ordinal, so we can
             # operate on the tz-naive equivalents
-            values = values.view("M8[ns]")
+            npvalues = values.view("M8[ns]")
             res_values = self._cython_operation(
-                kind, values, how, axis, min_count, **kwargs
+                kind, npvalues, how, axis, min_count, **kwargs
             )
             if how in ["rank"]:
+                # i.e. how in WrappedCythonOp.cast_blocklist, since
+                # other cast_blocklist methods dont go through cython_operation
                 # preserve float64 dtype
                 return res_values
 
@@ -582,12 +587,16 @@ def _ea_wrap_cython_operation(
             res_values = self._cython_operation(
                 kind, values, how, axis, min_count, **kwargs
             )
-            dtype = maybe_cast_result_dtype(orig_values.dtype, how)
-            if isinstance(dtype, ExtensionDtype):
-                cls = dtype.construct_array_type()
-                return cls._from_sequence(res_values, dtype=dtype)
+            if how in ["rank"]:
+                # i.e. how in WrappedCythonOp.cast_blocklist, since
+                # other cast_blocklist methods dont go through cython_operation
+                return res_values
 
-            return res_values
+            dtype = maybe_cast_result_dtype(orig_values.dtype, how)
+            # error: Item "dtype[Any]" of "Union[dtype[Any], ExtensionDtype]"
+            # has no attribute "construct_array_type"
+            cls = dtype.construct_array_type()  # type: ignore[union-attr]
+            return cls._from_sequence(res_values, dtype=dtype)
 
         elif is_float_dtype(values.dtype):
             # FloatingArray
@@ -595,8 +604,16 @@ def _ea_wrap_cython_operation(
             res_values = self._cython_operation(
                 kind, values, how, axis, min_count, **kwargs
             )
-            result = type(orig_values)._from_sequence(res_values)
-            return result
+            if how in ["rank"]:
+                # i.e. how in WrappedCythonOp.cast_blocklist, since
+                # other cast_blocklist methods dont go through cython_operation
+                return res_values
+
+            dtype = maybe_cast_result_dtype(orig_values.dtype, how)
+            # error: Item "dtype[Any]" of "Union[dtype[Any], ExtensionDtype]"
+            # has no attribute "construct_array_type"
+            cls = dtype.construct_array_type()  # type: ignore[union-attr]
+            return cls._from_sequence(res_values, dtype=dtype)
 
         raise NotImplementedError(
             f"function is not implemented for this dtype: {values.dtype}"
@@ -711,9 +728,9 @@ def _cython_operation(
 
         result = result.T
 
-        if how not in base.cython_cast_blocklist:
+        if how not in cy_op.cast_blocklist:
             # e.g. if we are int64 and need to restore to datetime64/timedelta64
-            # "rank" is the only member of cython_cast_blocklist we get here
+            # "rank" is the only member of cast_blocklist we get here
             dtype = maybe_cast_result_dtype(orig_values.dtype, how)
             op_result = maybe_downcast_to_dtype(result, dtype)
         else:
diff --git a/pandas/tests/groupby/test_counting.py b/pandas/tests/groupby/test_counting.py
index 1317f0f68216a..73b2d8ac2c1f5 100644
--- a/pandas/tests/groupby/test_counting.py
+++ b/pandas/tests/groupby/test_counting.py
@@ -209,6 +209,7 @@ def test_ngroup_respects_groupby_order(self):
     [
         [Timestamp(f"2016-05-{i:02d} 20:09:25+00:00") for i in range(1, 4)],
         [Timestamp(f"2016-05-{i:02d} 20:09:25") for i in range(1, 4)],
+        [Timestamp(f"2016-05-{i:02d} 20:09:25", tz="UTC") for i in range(1, 4)],
         [Timedelta(x, unit="h") for x in range(1, 4)],
         [Period(freq="2W", year=2017, month=x) for x in range(1, 4)],
     ],
diff --git a/pandas/tests/groupby/test_function.py b/pandas/tests/groupby/test_function.py
index d7020e2ffd701..46985ff956788 100644
--- a/pandas/tests/groupby/test_function.py
+++ b/pandas/tests/groupby/test_function.py
@@ -495,6 +495,8 @@ def test_idxmin_idxmax_returns_int_types(func, values):
     df["c_date_tz"] = df["c_date"].dt.tz_localize("US/Pacific")
     df["c_timedelta"] = df["c_date"] - df["c_date"].iloc[0]
     df["c_period"] = df["c_date"].dt.to_period("W")
+    df["c_Integer"] = df["c_int"].astype("Int64")
+    df["c_Floating"] = df["c_float"].astype("Float64")
 
     result = getattr(df.groupby("name"), func)()
 
@@ -502,6 +504,8 @@ def test_idxmin_idxmax_returns_int_types(func, values):
     expected["c_date_tz"] = expected["c_date"]
     expected["c_timedelta"] = expected["c_date"]
    expected["c_period"] = expected["c_date"]
+    expected["c_Integer"] = expected["c_int"]
+    expected["c_Floating"] = expected["c_float"]
 
     tm.assert_frame_equal(result, expected)
 
diff --git a/pandas/tests/groupby/test_groupby.py b/pandas/tests/groupby/test_groupby.py
index 2dab22910a0c9..c5620d6d8c06c 100644
--- a/pandas/tests/groupby/test_groupby.py
+++ b/pandas/tests/groupby/test_groupby.py
@@ -1732,6 +1732,8 @@ def test_pivot_table_values_key_error():
         [to_datetime(0)],
         [date_range(0, 1, 1, tz="US/Eastern")],
         [pd.array([0], dtype="Int64")],
+        [pd.array([0], dtype="Float64")],
+        [pd.array([False], dtype="boolean")],
     ],
 )
 @pytest.mark.parametrize("method", ["attr", "agg", "apply"])
diff --git a/pandas/tests/groupby/test_rank.py b/pandas/tests/groupby/test_rank.py
index 2e666c27386b4..da88ea5f05107 100644
--- a/pandas/tests/groupby/test_rank.py
+++ b/pandas/tests/groupby/test_rank.py
@@ -444,8 +444,19 @@ def test_rank_resets_each_group(pct, exp):
     tm.assert_frame_equal(result, exp_df)
 
 
-def test_rank_avg_even_vals():
+@pytest.mark.parametrize(
+    "dtype", ["int64", "int32", "uint64", "uint32", "float64", "float32"]
+)
+@pytest.mark.parametrize("upper", [True, False])
+def test_rank_avg_even_vals(dtype, upper):
+    if upper:
+        # use IntegerDtype/FloatingDtype
+        dtype = dtype[0].upper() + dtype[1:]
+        dtype = dtype.replace("Ui", "UI")
     df = DataFrame({"key": ["a"] * 4, "val": [1] * 4})
+    df["val"] = df["val"].astype(dtype)
+    assert df["val"].dtype == dtype
+
     result = df.groupby("key").rank()
     exp_df = DataFrame([2.5, 2.5, 2.5, 2.5], columns=["val"])
     tm.assert_frame_equal(result, exp_df)