|
46 | 46 | ensure_platform_int,
|
47 | 47 | is_1d_only_ea_dtype,
|
48 | 48 | is_bool_dtype,
|
49 |
| - is_categorical_dtype, |
50 | 49 | is_complex_dtype,
|
51 | 50 | is_datetime64_any_dtype,
|
52 | 51 | is_float_dtype,
|
|
56 | 55 | is_timedelta64_dtype,
|
57 | 56 | needs_i8_conversion,
|
58 | 57 | )
|
| 58 | +from pandas.core.dtypes.dtypes import CategoricalDtype |
59 | 59 | from pandas.core.dtypes.missing import (
|
60 | 60 | isna,
|
61 | 61 | maybe_fill,
|
62 | 62 | )
|
63 | 63 |
|
64 | 64 | from pandas.core.arrays import (
|
| 65 | + Categorical, |
65 | 66 | DatetimeArray,
|
66 | 67 | ExtensionArray,
|
67 | 68 | PeriodArray,
|
@@ -142,7 +143,15 @@ def __init__(self, kind: str, how: str, has_dropped_na: bool) -> None:
|
142 | 143 |
|
143 | 144 | # "group_any" and "group_all" are also support masks, but don't go
|
144 | 145 | # through WrappedCythonOp
|
145 |
| - _MASKED_CYTHON_FUNCTIONS = {"cummin", "cummax", "min", "max", "last", "first"} |
| 146 | + _MASKED_CYTHON_FUNCTIONS = { |
| 147 | + "cummin", |
| 148 | + "cummax", |
| 149 | + "min", |
| 150 | + "max", |
| 151 | + "last", |
| 152 | + "first", |
| 153 | + "rank", |
| 154 | + } |
146 | 155 |
|
147 | 156 | _cython_arity = {"ohlc": 4} # OHLC
|
148 | 157 |
|
@@ -229,12 +238,17 @@ def _disallow_invalid_ops(self, dtype: DtypeObj, is_numeric: bool = False):
|
229 | 238 | # never an invalid op for those dtypes, so return early as fastpath
|
230 | 239 | return
|
231 | 240 |
|
232 |
| - if is_categorical_dtype(dtype): |
| 241 | + if isinstance(dtype, CategoricalDtype): |
233 | 242 | # NotImplementedError for methods that can fall back to a
|
234 | 243 | # non-cython implementation.
|
235 | 244 | if how in ["add", "prod", "cumsum", "cumprod"]:
|
236 | 245 | raise TypeError(f"{dtype} type does not support {how} operations")
|
237 |
| - raise NotImplementedError(f"{dtype} dtype not supported") |
| 246 | + elif how not in ["rank"]: |
| 247 | + # only "rank" is implemented in cython |
| 248 | + raise NotImplementedError(f"{dtype} dtype not supported") |
| 249 | + elif not dtype.ordered: |
| 250 | + # TODO: TypeError? |
| 251 | + raise NotImplementedError(f"{dtype} dtype not supported") |
238 | 252 |
|
239 | 253 | elif is_sparse(dtype):
|
240 | 254 | # categoricals are only 1d, so we
|
@@ -332,6 +346,25 @@ def _ea_wrap_cython_operation(
|
332 | 346 | **kwargs,
|
333 | 347 | )
|
334 | 348 |
|
| 349 | + elif isinstance(values, Categorical) and self.uses_mask(): |
| 350 | + assert self.how == "rank" # the only one implemented ATM |
| 351 | + assert values.ordered # checked earlier |
| 352 | + mask = values.isna() |
| 353 | + npvalues = values._ndarray |
| 354 | + |
| 355 | + res_values = self._cython_op_ndim_compat( |
| 356 | + npvalues, |
| 357 | + min_count=min_count, |
| 358 | + ngroups=ngroups, |
| 359 | + comp_ids=comp_ids, |
| 360 | + mask=mask, |
| 361 | + **kwargs, |
| 362 | + ) |
| 363 | + |
| 364 | + # If we ever have more than just "rank" here, we'll need to do |
| 365 | + # `if self.how in self.cast_blocklist` like we do for other dtypes. |
| 366 | + return res_values |
| 367 | + |
335 | 368 | npvalues = self._ea_to_cython_values(values)
|
336 | 369 |
|
337 | 370 | res_values = self._cython_op_ndim_compat(
|
@@ -551,14 +584,16 @@ def _call_cython_op(
|
551 | 584 | else:
|
552 | 585 | # TODO: min_count
|
553 | 586 | if self.uses_mask():
|
| 587 | + if self.how != "rank": |
| 588 | + # TODO: should rank take result_mask? |
| 589 | + kwargs["result_mask"] = result_mask |
554 | 590 | func(
|
555 | 591 | out=result,
|
556 | 592 | values=values,
|
557 | 593 | labels=comp_ids,
|
558 | 594 | ngroups=ngroups,
|
559 | 595 | is_datetimelike=is_datetimelike,
|
560 | 596 | mask=mask,
|
561 |
| - result_mask=result_mask, |
562 | 597 | **kwargs,
|
563 | 598 | )
|
564 | 599 | else:
|
|
0 commit comments