Skip to content

Commit 40ca2b9

Browse files
authored
REF: implement _ea_wrap_cython_operation (#38162)
1 parent c4c1dc3 commit 40ca2b9

File tree

1 file changed

+70
-36
lines changed

1 file changed

+70
-36
lines changed

pandas/core/groupby/ops.py

+70 -36
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@
2424
from pandas._libs import NaT, iNaT, lib
2525
import pandas._libs.groupby as libgroupby
2626
import pandas._libs.reduction as libreduction
27-
from pandas._typing import F, FrameOrSeries, Label, Shape
27+
from pandas._typing import ArrayLike, F, FrameOrSeries, Label, Shape
2828
from pandas.errors import AbstractMethodError
2929
from pandas.util._decorators import cache_readonly
3030

@@ -445,6 +445,68 @@ def _get_cython_func_and_vals(
445445
raise
446446
return func, values
447447

448+
def _disallow_invalid_ops(self, values: ArrayLike, how: str):
449+
"""
450+
Check if we can do this operation with our cython functions.
451+
452+
Raises
453+
------
454+
NotImplementedError
455+
This is either not a valid function for this dtype, or
456+
valid but not implemented in cython.
457+
"""
458+
dtype = values.dtype
459+
460+
if is_categorical_dtype(dtype) or is_sparse(dtype):
461+
# categoricals are only 1d, so we
462+
# are not setup for dim transforming
463+
raise NotImplementedError(f"{dtype} dtype not supported")
464+
elif is_datetime64_any_dtype(dtype):
465+
# we raise NotImplemented if this is an invalid operation
466+
# entirely, e.g. adding datetimes
467+
if how in ["add", "prod", "cumsum", "cumprod"]:
468+
raise NotImplementedError(
469+
f"datetime64 type does not support {how} operations"
470+
)
471+
elif is_timedelta64_dtype(dtype):
472+
if how in ["prod", "cumprod"]:
473+
raise NotImplementedError(
474+
f"timedelta64 type does not support {how} operations"
475+
)
476+
477+
def _ea_wrap_cython_operation(
478+
self, kind: str, values, how: str, axis: int, min_count: int = -1, **kwargs
479+
) -> Tuple[np.ndarray, Optional[List[str]]]:
480+
"""
481+
If we have an ExtensionArray, unwrap, call _cython_operation, and
482+
re-wrap if appropriate.
483+
"""
484+
# TODO: general case implementation overrideable by EAs.
485+
orig_values = values
486+
487+
if is_datetime64tz_dtype(values.dtype) or is_period_dtype(values.dtype):
488+
# All of the functions implemented here are ordinal, so we can
489+
# operate on the tz-naive equivalents
490+
values = values.view("M8[ns]")
491+
res_values, names = self._cython_operation(
492+
kind, values, how, axis, min_count, **kwargs
493+
)
494+
res_values = res_values.astype("i8", copy=False)
495+
# FIXME: this is wrong for rank, but not tested.
496+
result = type(orig_values)._simple_new(res_values, dtype=orig_values.dtype)
497+
return result, names
498+
499+
elif is_integer_dtype(values.dtype) or is_bool_dtype(values.dtype):
500+
# IntegerArray or BooleanArray
501+
values = ensure_int_or_float(values)
502+
res_values, names = self._cython_operation(
503+
kind, values, how, axis, min_count, **kwargs
504+
)
505+
result = maybe_cast_result(result=res_values, obj=orig_values, how=how)
506+
return result, names
507+
508+
raise NotImplementedError(values.dtype)
509+
448510
def _cython_operation(
449511
self, kind: str, values, how: str, axis: int, min_count: int = -1, **kwargs
450512
) -> Tuple[np.ndarray, Optional[List[str]]]:
@@ -454,8 +516,8 @@ def _cython_operation(
454516
Names is only useful when dealing with 2D results, like ohlc
455517
(see self._name_functions).
456518
"""
457-
assert kind in ["transform", "aggregate"]
458519
orig_values = values
520+
assert kind in ["transform", "aggregate"]
459521

460522
if values.ndim > 2:
461523
raise NotImplementedError("number of dimensions is currently limited to 2")
@@ -466,30 +528,12 @@ def _cython_operation(
466528

467529
# can we do this operation with our cython functions
468530
# if not raise NotImplementedError
531+
self._disallow_invalid_ops(values, how)
469532

470-
# we raise NotImplemented if this is an invalid operation
471-
# entirely, e.g. adding datetimes
472-
473-
# categoricals are only 1d, so we
474-
# are not setup for dim transforming
475-
if is_categorical_dtype(values.dtype) or is_sparse(values.dtype):
476-
raise NotImplementedError(f"{values.dtype} dtype not supported")
477-
elif is_datetime64_any_dtype(values.dtype):
478-
if how in ["add", "prod", "cumsum", "cumprod"]:
479-
raise NotImplementedError(
480-
f"datetime64 type does not support {how} operations"
481-
)
482-
elif is_timedelta64_dtype(values.dtype):
483-
if how in ["prod", "cumprod"]:
484-
raise NotImplementedError(
485-
f"timedelta64 type does not support {how} operations"
486-
)
487-
488-
if is_datetime64tz_dtype(values.dtype):
489-
# Cast to naive; we'll cast back at the end of the function
490-
# TODO: possible need to reshape?
491-
# TODO(EA2D):kludge can be avoided when 2D EA is allowed.
492-
values = values.view("M8[ns]")
533+
if is_extension_array_dtype(values.dtype):
534+
return self._ea_wrap_cython_operation(
535+
kind, values, how, axis, min_count, **kwargs
536+
)
493537

494538
is_datetimelike = needs_i8_conversion(values.dtype)
495539
is_numeric = is_numeric_dtype(values.dtype)
@@ -573,19 +617,9 @@ def _cython_operation(
573617
if swapped:
574618
result = result.swapaxes(0, axis)
575619

576-
if is_datetime64tz_dtype(orig_values.dtype) or is_period_dtype(
577-
orig_values.dtype
578-
):
579-
# We need to use the constructors directly for these dtypes
580-
# since numpy won't recognize them
581-
# https://github.com/pandas-dev/pandas/issues/31471
582-
result = type(orig_values)(result.astype(np.int64), dtype=orig_values.dtype)
583-
elif is_datetimelike and kind == "aggregate":
620+
if is_datetimelike and kind == "aggregate":
584621
result = result.astype(orig_values.dtype)
585622

586-
if is_extension_array_dtype(orig_values.dtype):
587-
result = maybe_cast_result(result=result, obj=orig_values, how=how)
588-
589623
return result, names
590624

591625
def _aggregate(

0 commit comments

Comments
 (0)