31
31
import pandas ._libs .groupby as libgroupby
32
32
import pandas ._libs .reduction as libreduction
33
33
from pandas ._typing import (
34
- ArrayLike ,
34
+ DtypeObj ,
35
35
F ,
36
36
FrameOrSeries ,
37
37
Shape ,
46
46
maybe_downcast_to_dtype ,
47
47
)
48
48
from pandas .core .dtypes .common import (
49
- ensure_float ,
50
49
ensure_float64 ,
51
50
ensure_int64 ,
52
51
ensure_int_or_float ,
@@ -491,7 +490,9 @@ def _get_cython_func_and_vals(
491
490
return func , values
492
491
493
492
@final
494
- def _disallow_invalid_ops (self , values : ArrayLike , how : str ):
493
+ def _disallow_invalid_ops (
494
+ self , dtype : DtypeObj , how : str , is_numeric : bool = False
495
+ ):
495
496
"""
496
497
Check if we can do this operation with our cython functions.
497
498
@@ -501,7 +502,9 @@ def _disallow_invalid_ops(self, values: ArrayLike, how: str):
501
502
This is either not a valid function for this dtype, or
502
503
valid but not implemented in cython.
503
504
"""
504
- dtype = values .dtype
505
+ if is_numeric :
506
+ # never an invalid op for those dtypes, so return early as fastpath
507
+ return
505
508
506
509
if is_categorical_dtype (dtype ) or is_sparse (dtype ):
507
510
# categoricals are only 1d, so we
@@ -589,32 +592,34 @@ def _cython_operation(
589
592
# as we can have 1D ExtensionArrays that we need to treat as 2D
590
593
assert axis == 1 , axis
591
594
595
+ dtype = values .dtype
596
+ is_numeric = is_numeric_dtype (dtype )
597
+
592
598
# can we do this operation with our cython functions
593
599
# if not raise NotImplementedError
594
- self ._disallow_invalid_ops (values , how )
600
+ self ._disallow_invalid_ops (dtype , how , is_numeric )
595
601
596
- if is_extension_array_dtype (values . dtype ):
602
+ if is_extension_array_dtype (dtype ):
597
603
return self ._ea_wrap_cython_operation (
598
604
kind , values , how , axis , min_count , ** kwargs
599
605
)
600
606
601
- is_datetimelike = needs_i8_conversion (values .dtype )
602
- is_numeric = is_numeric_dtype (values .dtype )
607
+ is_datetimelike = needs_i8_conversion (dtype )
603
608
604
609
if is_datetimelike :
605
610
values = values .view ("int64" )
606
611
is_numeric = True
607
- elif is_bool_dtype (values . dtype ):
612
+ elif is_bool_dtype (dtype ):
608
613
values = ensure_int_or_float (values )
609
- elif is_integer_dtype (values ):
614
+ elif is_integer_dtype (dtype ):
610
615
# we use iNaT for the missing value on ints
611
616
# so pre-convert to guard this condition
612
617
if (values == iNaT ).any ():
613
618
values = ensure_float64 (values )
614
619
else :
615
620
values = ensure_int_or_float (values )
616
- elif is_numeric and not is_complex_dtype (values ):
617
- values = ensure_float64 (ensure_float ( values ) )
621
+ elif is_numeric and not is_complex_dtype (dtype ):
622
+ values = ensure_float64 (values )
618
623
else :
619
624
values = values .astype (object )
620
625
@@ -649,20 +654,18 @@ def _cython_operation(
649
654
codes , _ , _ = self .group_info
650
655
651
656
if kind == "aggregate" :
652
- result = maybe_fill (np .empty (out_shape , dtype = out_dtype ), fill_value = np . nan )
657
+ result = maybe_fill (np .empty (out_shape , dtype = out_dtype ))
653
658
counts = np .zeros (self .ngroups , dtype = np .int64 )
654
659
result = self ._aggregate (result , counts , values , codes , func , min_count )
655
660
elif kind == "transform" :
656
- result = maybe_fill (
657
- np .empty (values .shape , dtype = out_dtype ), fill_value = np .nan
658
- )
661
+ result = maybe_fill (np .empty (values .shape , dtype = out_dtype ))
659
662
660
663
# TODO: min_count
661
664
result = self ._transform (
662
665
result , values , codes , func , is_datetimelike , ** kwargs
663
666
)
664
667
665
- if is_integer_dtype (result ) and not is_datetimelike :
668
+ if is_integer_dtype (result . dtype ) and not is_datetimelike :
666
669
mask = result == iNaT
667
670
if mask .any ():
668
671
result = result .astype ("float64" )
@@ -682,9 +685,9 @@ def _cython_operation(
682
685
# e.g. if we are int64 and need to restore to datetime64/timedelta64
683
686
# "rank" is the only member of cython_cast_blocklist we get here
684
687
dtype = maybe_cast_result_dtype (orig_values .dtype , how )
685
- # error: Argument 2 to "maybe_downcast_to_dtype" has incompatible type
686
- # "Union[dtype[Any], ExtensionDtype]"; expected "Union[str, dtype[Any]]"
687
- result = maybe_downcast_to_dtype (result , dtype ) # type: ignore[arg-type ]
688
+ # error: Incompatible types in assignment (expression has type
689
+ # "Union[ExtensionArray, ndarray]", variable has type "ndarray")
690
+ result = maybe_downcast_to_dtype (result , dtype ) # type: ignore[assignment ]
688
691
689
692
return result
690
693
0 commit comments