Skip to content

Commit 9e09142

Browse files
authored
BUG: groupby.rank with MaskedArray incorrect casting (#41010)
1 parent 1b95e1b commit 9e09142

File tree

7 files changed

+49
-15
lines changed

7 files changed

+49
-15
lines changed

doc/source/whatsnew/v1.3.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -834,6 +834,7 @@ Groupby/resample/rolling
834834
- Bug in :class:`core.window.RollingGroupby` where ``as_index=False`` argument in ``groupby`` was ignored (:issue:`39433`)
835835
- Bug in :meth:`.GroupBy.any` and :meth:`.GroupBy.all` raising ``ValueError`` when using with nullable type columns holding ``NA`` even with ``skipna=True`` (:issue:`40585`)
836836
- Bug in :meth:`GroupBy.cummin` and :meth:`GroupBy.cummax` incorrectly rounding integer values near the ``int64`` implementations bounds (:issue:`40767`)
837+
- Bug in :meth:`.GroupBy.rank` with nullable dtypes incorrectly raising ``TypeError`` (:issue:`41010`)
837838

838839
Reshaping
839840
^^^^^^^^^

pandas/core/groupby/base.py

-2
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,6 @@
4747
# require postprocessing of the result by transform.
4848
cythonized_kernels = frozenset(["cumprod", "cumsum", "shift", "cummin", "cummax"])
4949

50-
cython_cast_blocklist = frozenset(["rank", "count", "size", "idxmin", "idxmax"])
51-
5250
# List of aggregation/reduction functions.
5351
# These map each group to a single numeric value
5452
reduction_kernels = frozenset(

pandas/core/groupby/ops.py

+29-12
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@
5858
is_timedelta64_dtype,
5959
needs_i8_conversion,
6060
)
61-
from pandas.core.dtypes.dtypes import ExtensionDtype
6261
from pandas.core.dtypes.generic import ABCCategoricalIndex
6362
from pandas.core.dtypes.missing import (
6463
isna,
@@ -95,6 +94,10 @@ class WrappedCythonOp:
9594
Dispatch logic for functions defined in _libs.groupby
9695
"""
9796

97+
# Functions for which we do _not_ attempt to cast the cython result
98+
# back to the original dtype.
99+
cast_blocklist = frozenset(["rank", "count", "size", "idxmin", "idxmax"])
100+
98101
def __init__(self, kind: str, how: str):
99102
self.kind = kind
100103
self.how = how
@@ -564,11 +567,13 @@ def _ea_wrap_cython_operation(
564567
if is_datetime64tz_dtype(values.dtype) or is_period_dtype(values.dtype):
565568
# All of the functions implemented here are ordinal, so we can
566569
# operate on the tz-naive equivalents
567-
values = values.view("M8[ns]")
570+
npvalues = values.view("M8[ns]")
568571
res_values = self._cython_operation(
569-
kind, values, how, axis, min_count, **kwargs
572+
kind, npvalues, how, axis, min_count, **kwargs
570573
)
571574
if how in ["rank"]:
575+
# i.e. how in WrappedCythonOp.cast_blocklist, since
576+
# other cast_blocklist methods dont go through cython_operation
572577
# preserve float64 dtype
573578
return res_values
574579

@@ -582,21 +587,33 @@ def _ea_wrap_cython_operation(
582587
res_values = self._cython_operation(
583588
kind, values, how, axis, min_count, **kwargs
584589
)
585-
dtype = maybe_cast_result_dtype(orig_values.dtype, how)
586-
if isinstance(dtype, ExtensionDtype):
587-
cls = dtype.construct_array_type()
588-
return cls._from_sequence(res_values, dtype=dtype)
590+
if how in ["rank"]:
591+
# i.e. how in WrappedCythonOp.cast_blocklist, since
592+
# other cast_blocklist methods dont go through cython_operation
593+
return res_values
589594

590-
return res_values
595+
dtype = maybe_cast_result_dtype(orig_values.dtype, how)
596+
# error: Item "dtype[Any]" of "Union[dtype[Any], ExtensionDtype]"
597+
# has no attribute "construct_array_type"
598+
cls = dtype.construct_array_type() # type: ignore[union-attr]
599+
return cls._from_sequence(res_values, dtype=dtype)
591600

592601
elif is_float_dtype(values.dtype):
593602
# FloatingArray
594603
values = values.to_numpy(values.dtype.numpy_dtype, na_value=np.nan)
595604
res_values = self._cython_operation(
596605
kind, values, how, axis, min_count, **kwargs
597606
)
598-
result = type(orig_values)._from_sequence(res_values)
599-
return result
607+
if how in ["rank"]:
608+
# i.e. how in WrappedCythonOp.cast_blocklist, since
609+
# other cast_blocklist methods dont go through cython_operation
610+
return res_values
611+
612+
dtype = maybe_cast_result_dtype(orig_values.dtype, how)
613+
# error: Item "dtype[Any]" of "Union[dtype[Any], ExtensionDtype]"
614+
# has no attribute "construct_array_type"
615+
cls = dtype.construct_array_type() # type: ignore[union-attr]
616+
return cls._from_sequence(res_values, dtype=dtype)
600617

601618
raise NotImplementedError(
602619
f"function is not implemented for this dtype: {values.dtype}"
@@ -711,9 +728,9 @@ def _cython_operation(
711728

712729
result = result.T
713730

714-
if how not in base.cython_cast_blocklist:
731+
if how not in cy_op.cast_blocklist:
715732
# e.g. if we are int64 and need to restore to datetime64/timedelta64
716-
# "rank" is the only member of cython_cast_blocklist we get here
733+
# "rank" is the only member of cast_blocklist we get here
717734
dtype = maybe_cast_result_dtype(orig_values.dtype, how)
718735
op_result = maybe_downcast_to_dtype(result, dtype)
719736
else:

pandas/tests/groupby/test_counting.py

+1
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ def test_ngroup_respects_groupby_order(self):
209209
[
210210
[Timestamp(f"2016-05-{i:02d} 20:09:25+00:00") for i in range(1, 4)],
211211
[Timestamp(f"2016-05-{i:02d} 20:09:25") for i in range(1, 4)],
212+
[Timestamp(f"2016-05-{i:02d} 20:09:25", tz="UTC") for i in range(1, 4)],
212213
[Timedelta(x, unit="h") for x in range(1, 4)],
213214
[Period(freq="2W", year=2017, month=x) for x in range(1, 4)],
214215
],

pandas/tests/groupby/test_function.py

+4
Original file line numberDiff line numberDiff line change
@@ -495,13 +495,17 @@ def test_idxmin_idxmax_returns_int_types(func, values):
495495
df["c_date_tz"] = df["c_date"].dt.tz_localize("US/Pacific")
496496
df["c_timedelta"] = df["c_date"] - df["c_date"].iloc[0]
497497
df["c_period"] = df["c_date"].dt.to_period("W")
498+
df["c_Integer"] = df["c_int"].astype("Int64")
499+
df["c_Floating"] = df["c_float"].astype("Float64")
498500

499501
result = getattr(df.groupby("name"), func)()
500502

501503
expected = DataFrame(values, index=Index(["A", "B"], name="name"))
502504
expected["c_date_tz"] = expected["c_date"]
503505
expected["c_timedelta"] = expected["c_date"]
504506
expected["c_period"] = expected["c_date"]
507+
expected["c_Integer"] = expected["c_int"]
508+
expected["c_Floating"] = expected["c_float"]
505509

506510
tm.assert_frame_equal(result, expected)
507511

pandas/tests/groupby/test_groupby.py

+2
Original file line numberDiff line numberDiff line change
@@ -1732,6 +1732,8 @@ def test_pivot_table_values_key_error():
17321732
[to_datetime(0)],
17331733
[date_range(0, 1, 1, tz="US/Eastern")],
17341734
[pd.array([0], dtype="Int64")],
1735+
[pd.array([0], dtype="Float64")],
1736+
[pd.array([False], dtype="boolean")],
17351737
],
17361738
)
17371739
@pytest.mark.parametrize("method", ["attr", "agg", "apply"])

pandas/tests/groupby/test_rank.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -444,8 +444,19 @@ def test_rank_resets_each_group(pct, exp):
444444
tm.assert_frame_equal(result, exp_df)
445445

446446

447-
def test_rank_avg_even_vals():
447+
@pytest.mark.parametrize(
448+
"dtype", ["int64", "int32", "uint64", "uint32", "float64", "float32"]
449+
)
450+
@pytest.mark.parametrize("upper", [True, False])
451+
def test_rank_avg_even_vals(dtype, upper):
452+
if upper:
453+
# use IntegerDtype/FloatingDtype
454+
dtype = dtype[0].upper() + dtype[1:]
455+
dtype = dtype.replace("Ui", "UI")
448456
df = DataFrame({"key": ["a"] * 4, "val": [1] * 4})
457+
df["val"] = df["val"].astype(dtype)
458+
assert df["val"].dtype == dtype
459+
449460
result = df.groupby("key").rank()
450461
exp_df = DataFrame([2.5, 2.5, 2.5, 2.5], columns=["val"])
451462
tm.assert_frame_equal(result, exp_df)

0 commit comments

Comments
 (0)