Skip to content

Commit c6c809d

Browse files
jbrockmendel authored and yehoshuadimarsky committed
ENH: support mask in libalgos.rank (pandas-dev#46932)
1 parent 59d967d commit c6c809d

File tree

6 files changed

+75
-8
lines changed

6 files changed

+75
-8
lines changed

pandas/_libs/algos.pyi

+1
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ def rank_1d(
109109
ascending: bool = ...,
110110
pct: bool = ...,
111111
na_option=...,
112+
mask: npt.NDArray[np.bool_] | None = ...,
112113
) -> np.ndarray: ... # np.ndarray[float64_t, ndim=1]
113114
def rank_2d(
114115
in_arr: np.ndarray, # ndarray[numeric_object_t, ndim=2]

pandas/_libs/algos.pyx

+7-2
Original file line numberDiff line numberDiff line change
@@ -889,6 +889,7 @@ def rank_1d(
889889
bint ascending=True,
890890
bint pct=False,
891891
na_option="keep",
892+
const uint8_t[:] mask=None,
892893
):
893894
"""
894895
Fast NaN-friendly version of ``scipy.stats.rankdata``.
@@ -918,6 +919,8 @@ def rank_1d(
918919
* keep: leave NA values where they are
919920
* top: smallest rank if ascending
920921
* bottom: smallest rank if descending
922+
mask : np.ndarray[bool], optional, default None
923+
Specify locations to be treated as NA, e.g. for Categorical.
921924
"""
922925
cdef:
923926
TiebreakEnumType tiebreak
@@ -927,7 +930,6 @@ def rank_1d(
927930
float64_t[::1] out
928931
ndarray[numeric_object_t, ndim=1] masked_vals
929932
numeric_object_t[:] masked_vals_memview
930-
uint8_t[:] mask
931933
bint keep_na, nans_rank_highest, check_labels, check_mask
932934
numeric_object_t nan_fill_val
933935

@@ -956,6 +958,7 @@ def rank_1d(
956958
or numeric_object_t is object
957959
or (numeric_object_t is int64_t and is_datetimelike)
958960
)
961+
check_mask = check_mask or mask is not None
959962

960963
# Copy values into new array in order to fill missing data
961964
# with mask, without obfuscating location of missing data
@@ -965,7 +968,9 @@ def rank_1d(
965968
else:
966969
masked_vals = values.copy()
967970

968-
if numeric_object_t is object:
971+
if mask is not None:
972+
pass
973+
elif numeric_object_t is object:
969974
mask = missing.isnaobj(masked_vals)
970975
elif numeric_object_t is int64_t and is_datetimelike:
971976
mask = (masked_vals == NPY_NAT).astype(np.uint8)

pandas/_libs/groupby.pyi

+1
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ def group_rank(
128128
ascending: bool = ...,
129129
pct: bool = ...,
130130
na_option: Literal["keep", "top", "bottom"] = ...,
131+
mask: npt.NDArray[np.bool_] | None = ...,
131132
) -> None: ...
132133
def group_max(
133134
out: np.ndarray, # groupby_t[:, ::1]

pandas/_libs/groupby.pyx

+10-1
Original file line numberDiff line numberDiff line change
@@ -1262,6 +1262,7 @@ def group_rank(
12621262
bint ascending=True,
12631263
bint pct=False,
12641264
str na_option="keep",
1265+
const uint8_t[:, :] mask=None,
12651266
) -> None:
12661267
"""
12671268
Provides the rank of values within each group.
@@ -1294,6 +1295,7 @@ def group_rank(
12941295
* keep: leave NA values where they are
12951296
* top: smallest rank if ascending
12961297
* bottom: smallest rank if descending
1298+
mask : np.ndarray[bool] or None, default None
12971299

12981300
Notes
12991301
-----
@@ -1302,18 +1304,25 @@ def group_rank(
13021304
cdef:
13031305
Py_ssize_t i, k, N
13041306
ndarray[float64_t, ndim=1] result
1307+
const uint8_t[:] sub_mask
13051308

13061309
N = values.shape[1]
13071310

13081311
for k in range(N):
1312+
if mask is None:
1313+
sub_mask = None
1314+
else:
1315+
sub_mask = mask[:, k]
1316+
13091317
result = rank_1d(
13101318
values=values[:, k],
13111319
labels=labels,
13121320
is_datetimelike=is_datetimelike,
13131321
ties_method=ties_method,
13141322
ascending=ascending,
13151323
pct=pct,
1316-
na_option=na_option
1324+
na_option=na_option,
1325+
mask=sub_mask,
13171326
)
13181327
for i in range(len(result)):
13191328
# TODO: why can't we do out[:, k] = result?

pandas/core/groupby/ops.py

+40-5
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
ensure_platform_int,
4747
is_1d_only_ea_dtype,
4848
is_bool_dtype,
49-
is_categorical_dtype,
5049
is_complex_dtype,
5150
is_datetime64_any_dtype,
5251
is_float_dtype,
@@ -56,12 +55,14 @@
5655
is_timedelta64_dtype,
5756
needs_i8_conversion,
5857
)
58+
from pandas.core.dtypes.dtypes import CategoricalDtype
5959
from pandas.core.dtypes.missing import (
6060
isna,
6161
maybe_fill,
6262
)
6363

6464
from pandas.core.arrays import (
65+
Categorical,
6566
DatetimeArray,
6667
ExtensionArray,
6768
PeriodArray,
@@ -142,7 +143,15 @@ def __init__(self, kind: str, how: str, has_dropped_na: bool) -> None:
142143

143144
# "group_any" and "group_all" also support masks, but don't go
144145
# through WrappedCythonOp
145-
_MASKED_CYTHON_FUNCTIONS = {"cummin", "cummax", "min", "max", "last", "first"}
146+
_MASKED_CYTHON_FUNCTIONS = {
147+
"cummin",
148+
"cummax",
149+
"min",
150+
"max",
151+
"last",
152+
"first",
153+
"rank",
154+
}
146155

147156
_cython_arity = {"ohlc": 4} # OHLC
148157

@@ -229,12 +238,17 @@ def _disallow_invalid_ops(self, dtype: DtypeObj, is_numeric: bool = False):
229238
# never an invalid op for those dtypes, so return early as fastpath
230239
return
231240

232-
if is_categorical_dtype(dtype):
241+
if isinstance(dtype, CategoricalDtype):
233242
# NotImplementedError for methods that can fall back to a
234243
# non-cython implementation.
235244
if how in ["add", "prod", "cumsum", "cumprod"]:
236245
raise TypeError(f"{dtype} type does not support {how} operations")
237-
raise NotImplementedError(f"{dtype} dtype not supported")
246+
elif how not in ["rank"]:
247+
# only "rank" is implemented in cython
248+
raise NotImplementedError(f"{dtype} dtype not supported")
249+
elif not dtype.ordered:
250+
# TODO: TypeError?
251+
raise NotImplementedError(f"{dtype} dtype not supported")
238252

239253
elif is_sparse(dtype):
240254
# categoricals are only 1d, so we
@@ -332,6 +346,25 @@ def _ea_wrap_cython_operation(
332346
**kwargs,
333347
)
334348

349+
elif isinstance(values, Categorical) and self.uses_mask():
350+
assert self.how == "rank" # the only one implemented ATM
351+
assert values.ordered # checked earlier
352+
mask = values.isna()
353+
npvalues = values._ndarray
354+
355+
res_values = self._cython_op_ndim_compat(
356+
npvalues,
357+
min_count=min_count,
358+
ngroups=ngroups,
359+
comp_ids=comp_ids,
360+
mask=mask,
361+
**kwargs,
362+
)
363+
364+
# If we ever have more than just "rank" here, we'll need to do
365+
# `if self.how in self.cast_blocklist` like we do for other dtypes.
366+
return res_values
367+
335368
npvalues = self._ea_to_cython_values(values)
336369

337370
res_values = self._cython_op_ndim_compat(
@@ -551,14 +584,16 @@ def _call_cython_op(
551584
else:
552585
# TODO: min_count
553586
if self.uses_mask():
587+
if self.how != "rank":
588+
# TODO: should rank take result_mask?
589+
kwargs["result_mask"] = result_mask
554590
func(
555591
out=result,
556592
values=values,
557593
labels=comp_ids,
558594
ngroups=ngroups,
559595
is_datetimelike=is_datetimelike,
560596
mask=mask,
561-
result_mask=result_mask,
562597
**kwargs,
563598
)
564599
else:

pandas/tests/groupby/test_rank.py

+16
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,8 @@ def test_rank_avg_even_vals(dtype, upper):
458458

459459
result = df.groupby("key").rank()
460460
exp_df = DataFrame([2.5, 2.5, 2.5, 2.5], columns=["val"])
461+
if upper:
462+
exp_df = exp_df.astype("Float64")
461463
tm.assert_frame_equal(result, exp_df)
462464

463465

@@ -663,3 +665,17 @@ def test_non_unique_index():
663665
name="value",
664666
)
665667
tm.assert_series_equal(result, expected)
668+
669+
670+
def test_rank_categorical():
671+
cat = pd.Categorical(["a", "a", "b", np.nan, "c", "b"], ordered=True)
672+
cat2 = pd.Categorical([1, 2, 3, np.nan, 4, 5], ordered=True)
673+
674+
df = DataFrame({"col1": [0, 1, 0, 1, 0, 1], "col2": cat, "col3": cat2})
675+
676+
gb = df.groupby("col1")
677+
678+
res = gb.rank()
679+
680+
expected = df.astype(object).groupby("col1").rank()
681+
tm.assert_frame_equal(res, expected)

0 commit comments

Comments
 (0)