14
14
Hashable ,
15
15
Iterator ,
16
16
Sequence ,
17
+ overload ,
17
18
)
18
19
19
20
import numpy as np
47
48
is_categorical_dtype ,
48
49
is_complex_dtype ,
49
50
is_datetime64_any_dtype ,
50
- is_datetime64tz_dtype ,
51
51
is_extension_array_dtype ,
52
- is_float_dtype ,
53
52
is_integer_dtype ,
54
53
is_numeric_dtype ,
55
- is_period_dtype ,
56
54
is_sparse ,
57
55
is_timedelta64_dtype ,
58
56
needs_i8_conversion ,
59
57
)
58
+ from pandas .core .dtypes .dtypes import ExtensionDtype
60
59
from pandas .core .dtypes .generic import ABCCategoricalIndex
61
60
from pandas .core .dtypes .missing import (
62
61
isna ,
63
62
maybe_fill ,
64
63
)
65
64
66
- from pandas .core .arrays import ExtensionArray
65
+ from pandas .core .arrays import (
66
+ DatetimeArray ,
67
+ ExtensionArray ,
68
+ PeriodArray ,
69
+ TimedeltaArray ,
70
+ )
71
+ from pandas .core .arrays .boolean import BooleanDtype
72
+ from pandas .core .arrays .floating import (
73
+ Float64Dtype ,
74
+ FloatingDtype ,
75
+ )
76
+ from pandas .core .arrays .integer import (
77
+ Int64Dtype ,
78
+ _IntegerDtype ,
79
+ )
67
80
from pandas .core .arrays .masked import (
68
81
BaseMaskedArray ,
69
82
BaseMaskedDtype ,
@@ -194,7 +207,7 @@ def get_cython_func_and_vals(self, values: np.ndarray, is_numeric: bool):
194
207
195
208
return func , values
196
209
197
- def disallow_invalid_ops (self , dtype : DtypeObj , is_numeric : bool = False ):
210
+ def _disallow_invalid_ops (self , dtype : DtypeObj , is_numeric : bool = False ):
198
211
"""
199
212
Check if we can do this operation with our cython functions.
200
213
@@ -230,7 +243,7 @@ def disallow_invalid_ops(self, dtype: DtypeObj, is_numeric: bool = False):
230
243
if how in ["prod" , "cumprod" ]:
231
244
raise TypeError (f"timedelta64 type does not support { how } operations" )
232
245
233
- def get_output_shape (self , ngroups : int , values : np .ndarray ) -> Shape :
246
+ def _get_output_shape (self , ngroups : int , values : np .ndarray ) -> Shape :
234
247
how = self .how
235
248
kind = self .kind
236
249
@@ -261,7 +274,15 @@ def get_out_dtype(self, dtype: np.dtype) -> np.dtype:
261
274
out_dtype = "object"
262
275
return np .dtype (out_dtype )
263
276
264
- def get_result_dtype (self , dtype : DtypeObj ) -> DtypeObj :
277
+ @overload
278
+ def _get_result_dtype (self , dtype : np .dtype ) -> np .dtype :
279
+ ...
280
+
281
+ @overload
282
+ def _get_result_dtype (self , dtype : ExtensionDtype ) -> ExtensionDtype :
283
+ ...
284
+
285
+ def _get_result_dtype (self , dtype : DtypeObj ) -> DtypeObj :
265
286
"""
266
287
Get the desired dtype of a result based on the
267
288
input dtype and how it was computed.
@@ -276,13 +297,6 @@ def get_result_dtype(self, dtype: DtypeObj) -> DtypeObj:
276
297
np.dtype or ExtensionDtype
277
298
The desired dtype of the result.
278
299
"""
279
- from pandas .core .arrays .boolean import BooleanDtype
280
- from pandas .core .arrays .floating import Float64Dtype
281
- from pandas .core .arrays .integer import (
282
- Int64Dtype ,
283
- _IntegerDtype ,
284
- )
285
-
286
300
how = self .how
287
301
288
302
if how in ["add" , "cumsum" , "sum" , "prod" ]:
@@ -315,15 +329,12 @@ def _ea_wrap_cython_operation(
315
329
# TODO: general case implementation overridable by EAs.
316
330
orig_values = values
317
331
318
- if is_datetime64tz_dtype ( values . dtype ) or is_period_dtype ( values . dtype ):
332
+ if isinstance ( orig_values , ( DatetimeArray , PeriodArray ) ):
319
333
# All of the functions implemented here are ordinal, so we can
320
334
# operate on the tz-naive equivalents
321
- npvalues = values .view ("M8[ns]" )
335
+ npvalues = orig_values . _ndarray .view ("M8[ns]" )
322
336
res_values = self ._cython_op_ndim_compat (
323
- # error: Argument 1 to "_cython_op_ndim_compat" of
324
- # "WrappedCythonOp" has incompatible type
325
- # "Union[ExtensionArray, ndarray]"; expected "ndarray"
326
- npvalues , # type: ignore[arg-type]
337
+ npvalues ,
327
338
min_count = min_count ,
328
339
ngroups = ngroups ,
329
340
comp_ids = comp_ids ,
@@ -336,14 +347,31 @@ def _ea_wrap_cython_operation(
336
347
# preserve float64 dtype
337
348
return res_values
338
349
339
- res_values = res_values .astype ("i8" , copy = False )
340
- # error: Too many arguments for "ExtensionArray"
341
- result = type (orig_values )( # type: ignore[call-arg]
342
- res_values , dtype = orig_values .dtype
350
+ res_values = res_values .view ("i8" )
351
+ result = type (orig_values )(res_values , dtype = orig_values .dtype )
352
+ return result
353
+
354
+ elif isinstance (orig_values , TimedeltaArray ):
355
+ # We have an ExtensionArray but not ExtensionDtype
356
+ res_values = self ._cython_op_ndim_compat (
357
+ orig_values ._ndarray ,
358
+ min_count = min_count ,
359
+ ngroups = ngroups ,
360
+ comp_ids = comp_ids ,
361
+ mask = None ,
362
+ ** kwargs ,
343
363
)
364
+ if self .how in ["rank" ]:
365
+ # i.e. how in WrappedCythonOp.cast_blocklist, since
366
+ # other cast_blocklist methods dont go through cython_operation
367
+ # preserve float64 dtype
368
+ return res_values
369
+
370
+ # otherwise res_values has the same dtype as original values
371
+ result = type (orig_values )(res_values )
344
372
return result
345
373
346
- elif is_integer_dtype (values .dtype ) or is_bool_dtype ( values . dtype ):
374
+ elif isinstance (values .dtype , ( BooleanDtype , _IntegerDtype ) ):
347
375
# IntegerArray or BooleanArray
348
376
npvalues = values .to_numpy ("float64" , na_value = np .nan )
349
377
res_values = self ._cython_op_ndim_compat (
@@ -359,17 +387,14 @@ def _ea_wrap_cython_operation(
359
387
# other cast_blocklist methods don't go through cython_operation
360
388
return res_values
361
389
362
- dtype = self .get_result_dtype (orig_values .dtype )
363
- # error: Item "dtype[Any]" of "Union[dtype[Any], ExtensionDtype]"
364
- # has no attribute "construct_array_type"
365
- cls = dtype .construct_array_type () # type: ignore[union-attr]
390
+ dtype = self ._get_result_dtype (orig_values .dtype )
391
+ cls = dtype .construct_array_type ()
366
392
return cls ._from_sequence (res_values , dtype = dtype )
367
393
368
- elif is_float_dtype (values .dtype ):
394
+ elif isinstance (values .dtype , FloatingDtype ):
369
395
# FloatingArray
370
- # error: "ExtensionDtype" has no attribute "numpy_dtype"
371
396
npvalues = values .to_numpy (
372
- values .dtype .numpy_dtype , # type: ignore[attr-defined]
397
+ values .dtype .numpy_dtype ,
373
398
na_value = np .nan ,
374
399
)
375
400
res_values = self ._cython_op_ndim_compat (
@@ -385,10 +410,8 @@ def _ea_wrap_cython_operation(
385
410
# other cast_blocklist methods don't go through cython_operation
386
411
return res_values
387
412
388
- dtype = self .get_result_dtype (orig_values .dtype )
389
- # error: Item "dtype[Any]" of "Union[dtype[Any], ExtensionDtype]"
390
- # has no attribute "construct_array_type"
391
- cls = dtype .construct_array_type () # type: ignore[union-attr]
413
+ dtype = self ._get_result_dtype (orig_values .dtype )
414
+ cls = dtype .construct_array_type ()
392
415
return cls ._from_sequence (res_values , dtype = dtype )
393
416
394
417
raise NotImplementedError (
@@ -422,12 +445,13 @@ def _masked_ea_wrap_cython_operation(
422
445
mask = mask ,
423
446
** kwargs ,
424
447
)
425
- dtype = self .get_result_dtype (orig_values .dtype )
448
+ dtype = self ._get_result_dtype (orig_values .dtype )
426
449
assert isinstance (dtype , BaseMaskedDtype )
427
450
cls = dtype .construct_array_type ()
428
451
429
452
return cls (res_values .astype (dtype .type , copy = False ), mask )
430
453
454
+ @final
431
455
def _cython_op_ndim_compat (
432
456
self ,
433
457
values : np .ndarray ,
@@ -500,7 +524,7 @@ def _call_cython_op(
500
524
if mask is not None :
501
525
mask = mask .reshape (values .shape , order = "C" )
502
526
503
- out_shape = self .get_output_shape (ngroups , values )
527
+ out_shape = self ._get_output_shape (ngroups , values )
504
528
func , values = self .get_cython_func_and_vals (values , is_numeric )
505
529
out_dtype = self .get_out_dtype (values .dtype )
506
530
@@ -550,19 +574,71 @@ def _call_cython_op(
550
574
if self .how not in self .cast_blocklist :
551
575
# e.g. if we are int64 and need to restore to datetime64/timedelta64
552
576
# "rank" is the only member of cast_blocklist we get here
553
- res_dtype = self .get_result_dtype (orig_values .dtype )
554
- # error: Argument 2 to "maybe_downcast_to_dtype" has incompatible type
555
- # "Union[dtype[Any], ExtensionDtype]"; expected "Union[str, dtype[Any]]"
556
- op_result = maybe_downcast_to_dtype (
557
- result , res_dtype # type: ignore[arg-type]
558
- )
577
+ res_dtype = self ._get_result_dtype (orig_values .dtype )
578
+ op_result = maybe_downcast_to_dtype (result , res_dtype )
559
579
else :
560
580
op_result = result
561
581
562
582
# error: Incompatible return value type (got "Union[ExtensionArray, ndarray]",
563
583
# expected "ndarray")
564
584
return op_result # type: ignore[return-value]
565
585
586
@final
def cython_operation(
    self,
    *,
    values: ArrayLike,
    axis: int,
    min_count: int = -1,
    comp_ids: np.ndarray,
    ngroups: int,
    **kwargs,
) -> ArrayLike:
    """
    Call our cython function, with appropriate pre- and post- processing.

    Parameters
    ----------
    values : np.ndarray or ExtensionArray
        Data to operate on; at most 2-dimensional.
    axis : int
        Must be 1 when ``values`` is 2-dimensional. It is only validated
        here and is not forwarded to the helpers.
    min_count : int, default -1
        Forwarded unchanged to the helper that performs the operation.
    comp_ids : np.ndarray
        Forwarded unchanged to the helper that performs the operation.
    ngroups : int
        Forwarded unchanged to the helper that performs the operation.
    **kwargs
        Forwarded unchanged to the helper that performs the operation.

    Returns
    -------
    np.ndarray or ExtensionArray
        Whatever the selected helper returns.

    Raises
    ------
    NotImplementedError
        If ``values`` has more than 2 dimensions.
    """
    if values.ndim > 2:
        raise NotImplementedError("number of dimensions is currently limited to 2")
    elif values.ndim == 2:
        # Note: it is *not* the case that axis is always 0 for 1-dim values,
        #  as we can have 1D ExtensionArrays that we need to treat as 2D
        assert axis == 1, axis

    dtype = values.dtype
    is_numeric = is_numeric_dtype(dtype)

    # can we do this operation with our cython functions
    #  if not raise NotImplementedError
    self._disallow_invalid_ops(dtype, is_numeric)

    if not isinstance(values, np.ndarray):
        # i.e. ExtensionArray
        if isinstance(values, BaseMaskedArray) and self.uses_mask():
            # masked arrays take the mask-aware path when the op supports it
            return self._masked_ea_wrap_cython_operation(
                values,
                min_count=min_count,
                ngroups=ngroups,
                comp_ids=comp_ids,
                **kwargs,
            )
        else:
            return self._ea_wrap_cython_operation(
                values,
                min_count=min_count,
                ngroups=ngroups,
                comp_ids=comp_ids,
                **kwargs,
            )

    # plain ndarray: no mask is passed on this path
    return self._cython_op_ndim_compat(
        values,
        min_count=min_count,
        ngroups=ngroups,
        comp_ids=comp_ids,
        mask=None,
        **kwargs,
    )
641
+
566
642
567
643
class BaseGrouper :
568
644
"""
@@ -799,6 +875,7 @@ def group_info(self):
799
875
800
876
ngroups = len (obs_group_ids )
801
877
comp_ids = ensure_platform_int (comp_ids )
878
+
802
879
return comp_ids , obs_group_ids , ngroups
803
880
804
881
@final
@@ -868,58 +945,23 @@ def _cython_operation(
868
945
how : str ,
869
946
axis : int ,
870
947
min_count : int = - 1 ,
871
- mask : np .ndarray | None = None ,
872
948
** kwargs ,
873
949
) -> ArrayLike :
874
950
"""
875
951
Returns the values of a cython operation.
876
952
"""
877
953
assert kind in ["transform" , "aggregate" ]
878
954
879
- if values .ndim > 2 :
880
- raise NotImplementedError ("number of dimensions is currently limited to 2" )
881
- elif values .ndim == 2 :
882
- # Note: it is *not* the case that axis is always 0 for 1-dim values,
883
- # as we can have 1D ExtensionArrays that we need to treat as 2D
884
- assert axis == 1 , axis
885
-
886
- dtype = values .dtype
887
- is_numeric = is_numeric_dtype (dtype )
888
-
889
955
cy_op = WrappedCythonOp (kind = kind , how = how )
890
956
891
- # can we do this operation with our cython functions
892
- # if not raise NotImplementedError
893
- cy_op .disallow_invalid_ops (dtype , is_numeric )
894
-
895
957
comp_ids , _ , _ = self .group_info
896
958
ngroups = self .ngroups
897
-
898
- func_uses_mask = cy_op .uses_mask ()
899
- if is_extension_array_dtype (dtype ):
900
- if isinstance (values , BaseMaskedArray ) and func_uses_mask :
901
- return cy_op ._masked_ea_wrap_cython_operation (
902
- values ,
903
- min_count = min_count ,
904
- ngroups = ngroups ,
905
- comp_ids = comp_ids ,
906
- ** kwargs ,
907
- )
908
- else :
909
- return cy_op ._ea_wrap_cython_operation (
910
- values ,
911
- min_count = min_count ,
912
- ngroups = ngroups ,
913
- comp_ids = comp_ids ,
914
- ** kwargs ,
915
- )
916
-
917
- return cy_op ._cython_op_ndim_compat (
918
- values ,
959
+ return cy_op .cython_operation (
960
+ values = values ,
961
+ axis = axis ,
919
962
min_count = min_count ,
920
- ngroups = self .ngroups ,
921
963
comp_ids = comp_ids ,
922
- mask = mask ,
964
+ ngroups = ngroups ,
923
965
** kwargs ,
924
966
)
925
967
@@ -969,8 +1011,8 @@ def _aggregate_series_fast(
969
1011
indexer = get_group_index_sorter (group_index , ngroups )
970
1012
obj = obj .take (indexer )
971
1013
group_index = group_index .take (indexer )
972
- grouper = libreduction .SeriesGrouper (obj , func , group_index , ngroups )
973
- result , counts = grouper .get_result ()
1014
+ sgrouper = libreduction .SeriesGrouper (obj , func , group_index , ngroups )
1015
+ result , counts = sgrouper .get_result ()
974
1016
return result , counts
975
1017
976
1018
@final
@@ -1167,8 +1209,8 @@ def _aggregate_series_fast(
1167
1209
# - obj is backed by an ndarray, not ExtensionArray
1168
1210
# - ngroups != 0
1169
1211
# - len(self.bins) > 0
1170
- grouper = libreduction .SeriesBinGrouper (obj , func , self .bins )
1171
- return grouper .get_result ()
1212
+ sbg = libreduction .SeriesBinGrouper (obj , func , self .bins )
1213
+ return sbg .get_result ()
1172
1214
1173
1215
1174
1216
def _is_indexed_like (obj , axes , axis : int ) -> bool :
0 commit comments