24
24
from pandas ._libs import NaT , iNaT , lib
25
25
import pandas ._libs .groupby as libgroupby
26
26
import pandas ._libs .reduction as libreduction
27
- from pandas ._typing import F , FrameOrSeries , Label , Shape
27
+ from pandas ._typing import ArrayLike , F , FrameOrSeries , Label , Shape
28
28
from pandas .errors import AbstractMethodError
29
29
from pandas .util ._decorators import cache_readonly
30
30
@@ -445,6 +445,68 @@ def _get_cython_func_and_vals(
445
445
raise
446
446
return func , values
447
447
448
+ def _disallow_invalid_ops (self , values : ArrayLike , how : str ):
449
+ """
450
+ Check if we can do this operation with our cython functions.
451
+
452
+ Raises
453
+ ------
454
+ NotImplementedError
455
+ This is either not a valid function for this dtype, or
456
+ valid but not implemented in cython.
457
+ """
458
+ dtype = values .dtype
459
+
460
+ if is_categorical_dtype (dtype ) or is_sparse (dtype ):
461
+ # categoricals are only 1d, so we
462
+ # are not setup for dim transforming
463
+ raise NotImplementedError (f"{ dtype } dtype not supported" )
464
+ elif is_datetime64_any_dtype (dtype ):
465
+ # we raise NotImplemented if this is an invalid operation
466
+ # entirely, e.g. adding datetimes
467
+ if how in ["add" , "prod" , "cumsum" , "cumprod" ]:
468
+ raise NotImplementedError (
469
+ f"datetime64 type does not support { how } operations"
470
+ )
471
+ elif is_timedelta64_dtype (dtype ):
472
+ if how in ["prod" , "cumprod" ]:
473
+ raise NotImplementedError (
474
+ f"timedelta64 type does not support { how } operations"
475
+ )
476
+
477
+ def _ea_wrap_cython_operation (
478
+ self , kind : str , values , how : str , axis : int , min_count : int = - 1 , ** kwargs
479
+ ) -> Tuple [np .ndarray , Optional [List [str ]]]:
480
+ """
481
+ If we have an ExtensionArray, unwrap, call _cython_operation, and
482
+ re-wrap if appropriate.
483
+ """
484
+ # TODO: general case implementation overrideable by EAs.
485
+ orig_values = values
486
+
487
+ if is_datetime64tz_dtype (values .dtype ) or is_period_dtype (values .dtype ):
488
+ # All of the functions implemented here are ordinal, so we can
489
+ # operate on the tz-naive equivalents
490
+ values = values .view ("M8[ns]" )
491
+ res_values , names = self ._cython_operation (
492
+ kind , values , how , axis , min_count , ** kwargs
493
+ )
494
+ res_values = res_values .astype ("i8" , copy = False )
495
+ # FIXME: this is wrong for rank, but not tested.
496
+ result = type (orig_values )._simple_new (res_values , dtype = orig_values .dtype )
497
+ return result , names
498
+
499
+ elif is_integer_dtype (values .dtype ) or is_bool_dtype (values .dtype ):
500
+ # IntegerArray or BooleanArray
501
+ values = ensure_int_or_float (values )
502
+ res_values , names = self ._cython_operation (
503
+ kind , values , how , axis , min_count , ** kwargs
504
+ )
505
+ result = maybe_cast_result (result = res_values , obj = orig_values , how = how )
506
+ return result , names
507
+
508
+ raise NotImplementedError (values .dtype )
509
+
448
510
def _cython_operation (
449
511
self , kind : str , values , how : str , axis : int , min_count : int = - 1 , ** kwargs
450
512
) -> Tuple [np .ndarray , Optional [List [str ]]]:
@@ -454,8 +516,8 @@ def _cython_operation(
454
516
Names is only useful when dealing with 2D results, like ohlc
455
517
(see self._name_functions).
456
518
"""
457
- assert kind in ["transform" , "aggregate" ]
458
519
orig_values = values
520
+ assert kind in ["transform" , "aggregate" ]
459
521
460
522
if values .ndim > 2 :
461
523
raise NotImplementedError ("number of dimensions is currently limited to 2" )
@@ -466,30 +528,12 @@ def _cython_operation(
466
528
467
529
# can we do this operation with our cython functions
468
530
# if not raise NotImplementedError
531
+ self ._disallow_invalid_ops (values , how )
469
532
470
- # we raise NotImplemented if this is an invalid operation
471
- # entirely, e.g. adding datetimes
472
-
473
- # categoricals are only 1d, so we
474
- # are not setup for dim transforming
475
- if is_categorical_dtype (values .dtype ) or is_sparse (values .dtype ):
476
- raise NotImplementedError (f"{ values .dtype } dtype not supported" )
477
- elif is_datetime64_any_dtype (values .dtype ):
478
- if how in ["add" , "prod" , "cumsum" , "cumprod" ]:
479
- raise NotImplementedError (
480
- f"datetime64 type does not support { how } operations"
481
- )
482
- elif is_timedelta64_dtype (values .dtype ):
483
- if how in ["prod" , "cumprod" ]:
484
- raise NotImplementedError (
485
- f"timedelta64 type does not support { how } operations"
486
- )
487
-
488
- if is_datetime64tz_dtype (values .dtype ):
489
- # Cast to naive; we'll cast back at the end of the function
490
- # TODO: possible need to reshape?
491
- # TODO(EA2D):kludge can be avoided when 2D EA is allowed.
492
- values = values .view ("M8[ns]" )
533
+ if is_extension_array_dtype (values .dtype ):
534
+ return self ._ea_wrap_cython_operation (
535
+ kind , values , how , axis , min_count , ** kwargs
536
+ )
493
537
494
538
is_datetimelike = needs_i8_conversion (values .dtype )
495
539
is_numeric = is_numeric_dtype (values .dtype )
@@ -573,19 +617,9 @@ def _cython_operation(
573
617
if swapped :
574
618
result = result .swapaxes (0 , axis )
575
619
576
- if is_datetime64tz_dtype (orig_values .dtype ) or is_period_dtype (
577
- orig_values .dtype
578
- ):
579
- # We need to use the constructors directly for these dtypes
580
- # since numpy won't recognize them
581
- # https://github.com/pandas-dev/pandas/issues/31471
582
- result = type (orig_values )(result .astype (np .int64 ), dtype = orig_values .dtype )
583
- elif is_datetimelike and kind == "aggregate" :
620
+ if is_datetimelike and kind == "aggregate" :
584
621
result = result .astype (orig_values .dtype )
585
622
586
- if is_extension_array_dtype (orig_values .dtype ):
587
- result = maybe_cast_result (result = result , obj = orig_values , how = how )
588
-
589
623
return result , names
590
624
591
625
def _aggregate (
0 commit comments