Skip to content

Commit b309683

Browse files
jbrockmendelyeshsurya
authored andcommitted
BUG: groupby.rank with MaskedArray incorrect casting (pandas-dev#41010)
1 parent 6624c5b commit b309683

File tree

2 files changed

+30
-12
lines changed

2 files changed

+30
-12
lines changed

doc/source/whatsnew/v1.3.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -837,6 +837,7 @@ Groupby/resample/rolling
837837
- Bug in :class:`core.window.RollingGroupby` where ``as_index=False`` argument in ``groupby`` was ignored (:issue:`39433`)
838838
- Bug in :meth:`.GroupBy.any` and :meth:`.GroupBy.all` raising ``ValueError`` when using with nullable type columns holding ``NA`` even with ``skipna=True`` (:issue:`40585`)
839839
- Bug in :meth:`GroupBy.cummin` and :meth:`GroupBy.cummax` incorrectly rounding integer values near the ``int64`` implementations bounds (:issue:`40767`)
840+
- Bug in :meth:`.GroupBy.rank` with nullable dtypes incorrectly raising ``TypeError`` (:issue:`41010`)
840841

841842

842843
Reshaping

pandas/core/groupby/ops.py

+29-12
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@
5858
is_timedelta64_dtype,
5959
needs_i8_conversion,
6060
)
61-
from pandas.core.dtypes.dtypes import ExtensionDtype
6261
from pandas.core.dtypes.generic import ABCCategoricalIndex
6362
from pandas.core.dtypes.missing import (
6463
isna,
@@ -95,6 +94,10 @@ class WrappedCythonOp:
9594
Dispatch logic for functions defined in _libs.groupby
9695
"""
9796

97+
# Functions for which we do _not_ attempt to cast the cython result
98+
# back to the original dtype.
99+
cast_blocklist = frozenset(["rank", "count", "size", "idxmin", "idxmax"])
100+
98101
def __init__(self, kind: str, how: str):
99102
self.kind = kind
100103
self.how = how
@@ -564,11 +567,13 @@ def _ea_wrap_cython_operation(
564567
if is_datetime64tz_dtype(values.dtype) or is_period_dtype(values.dtype):
565568
# All of the functions implemented here are ordinal, so we can
566569
# operate on the tz-naive equivalents
567-
values = values.view("M8[ns]")
570+
npvalues = values.view("M8[ns]")
568571
res_values = self._cython_operation(
569-
kind, values, how, axis, min_count, **kwargs
572+
kind, npvalues, how, axis, min_count, **kwargs
570573
)
571574
if how in ["rank"]:
575+
# i.e. how in WrappedCythonOp.cast_blocklist, since
576+
# other cast_blocklist methods dont go through cython_operation
572577
# preserve float64 dtype
573578
return res_values
574579

@@ -582,21 +587,33 @@ def _ea_wrap_cython_operation(
582587
res_values = self._cython_operation(
583588
kind, values, how, axis, min_count, **kwargs
584589
)
585-
dtype = maybe_cast_result_dtype(orig_values.dtype, how)
586-
if isinstance(dtype, ExtensionDtype):
587-
cls = dtype.construct_array_type()
588-
return cls._from_sequence(res_values, dtype=dtype)
590+
if how in ["rank"]:
591+
# i.e. how in WrappedCythonOp.cast_blocklist, since
592+
# other cast_blocklist methods dont go through cython_operation
593+
return res_values
589594

590-
return res_values
595+
dtype = maybe_cast_result_dtype(orig_values.dtype, how)
596+
# error: Item "dtype[Any]" of "Union[dtype[Any], ExtensionDtype]"
597+
# has no attribute "construct_array_type"
598+
cls = dtype.construct_array_type() # type: ignore[union-attr]
599+
return cls._from_sequence(res_values, dtype=dtype)
591600

592601
elif is_float_dtype(values.dtype):
593602
# FloatingArray
594603
values = values.to_numpy(values.dtype.numpy_dtype, na_value=np.nan)
595604
res_values = self._cython_operation(
596605
kind, values, how, axis, min_count, **kwargs
597606
)
598-
result = type(orig_values)._from_sequence(res_values)
599-
return result
607+
if how in ["rank"]:
608+
# i.e. how in WrappedCythonOp.cast_blocklist, since
609+
# other cast_blocklist methods dont go through cython_operation
610+
return res_values
611+
612+
dtype = maybe_cast_result_dtype(orig_values.dtype, how)
613+
# error: Item "dtype[Any]" of "Union[dtype[Any], ExtensionDtype]"
614+
# has no attribute "construct_array_type"
615+
cls = dtype.construct_array_type() # type: ignore[union-attr]
616+
return cls._from_sequence(res_values, dtype=dtype)
600617

601618
raise NotImplementedError(
602619
f"function is not implemented for this dtype: {values.dtype}"
@@ -711,9 +728,9 @@ def _cython_operation(
711728

712729
result = result.T
713730

714-
if how not in base.cython_cast_blocklist:
731+
if how not in cy_op.cast_blocklist:
715732
# e.g. if we are int64 and need to restore to datetime64/timedelta64
716-
# "rank" is the only member of cython_cast_blocklist we get here
733+
# "rank" is the only member of cast_blocklist we get here
717734
dtype = maybe_cast_result_dtype(orig_values.dtype, how)
718735
op_result = maybe_downcast_to_dtype(result, dtype)
719736
else:

0 commit comments

Comments
 (0)