97
97
get_indexer_dict ,
98
98
)
99
99
100
- _CYTHON_FUNCTIONS = {
101
- "aggregate" : {
102
- "add" : "group_add" ,
103
- "prod" : "group_prod" ,
104
- "min" : "group_min" ,
105
- "max" : "group_max" ,
106
- "mean" : "group_mean" ,
107
- "median" : "group_median" ,
108
- "var" : "group_var" ,
109
- "first" : "group_nth" ,
110
- "last" : "group_last" ,
111
- "ohlc" : "group_ohlc" ,
112
- },
113
- "transform" : {
114
- "cumprod" : "group_cumprod" ,
115
- "cumsum" : "group_cumsum" ,
116
- "cummin" : "group_cummin" ,
117
- "cummax" : "group_cummax" ,
118
- "rank" : "group_rank" ,
119
- },
120
- }
121
-
122
-
123
- @functools .lru_cache (maxsize = None )
124
- def _get_cython_function (kind : str , how : str , dtype : np .dtype , is_numeric : bool ):
125
-
126
- dtype_str = dtype .name
127
- ftype = _CYTHON_FUNCTIONS [kind ][how ]
128
-
129
- # see if there is a fused-type version of function
130
- # only valid for numeric
131
- f = getattr (libgroupby , ftype , None )
132
- if f is not None :
133
- if is_numeric :
134
- return f
135
- elif dtype == object :
136
- if "object" not in f .__signatures__ :
137
- # raise NotImplementedError here rather than TypeError later
100
+
101
class WrappedCythonOp:
    """
    Dispatch logic for functions defined in _libs.groupby

    Parameters
    ----------
    kind : str
        First-level key into ``_CYTHON_FUNCTIONS`` ("aggregate" or
        "transform").
    how : str
        Name of the operation, e.g. "add", "ohlc" or "cumsum".
    """

    # kind -> how -> name of the attribute looked up on libgroupby
    _CYTHON_FUNCTIONS = {
        "aggregate": {
            "add": "group_add",
            "prod": "group_prod",
            "min": "group_min",
            "max": "group_max",
            "mean": "group_mean",
            "median": "group_median",
            "var": "group_var",
            "first": "group_nth",
            "last": "group_last",
            "ohlc": "group_ohlc",
        },
        "transform": {
            "cumprod": "group_cumprod",
            "cumsum": "group_cumsum",
            "cummin": "group_cummin",
            "cummax": "group_cummax",
            "rank": "group_rank",
        },
    }

    # ops missing from this table are treated as arity 1 (see get_output_shape)
    _cython_arity = {"ohlc": 4}  # OHLC

    def __init__(self, kind: str, how: str):
        self.kind = kind
        self.how = how

    # Note: we make this a classmethod and pass kind+how so that caching
    #  works at the class level and not the instance level
    @classmethod
    @functools.lru_cache(maxsize=None)
    def _get_cython_function(
        cls, kind: str, how: str, dtype: np.dtype, is_numeric: bool
    ):
        """
        Look up the libgroupby function registered for ``kind``/``how``.

        Parameters
        ----------
        kind : str
        how : str
        dtype : np.dtype
        is_numeric : bool

        Returns
        -------
        callable
            The attribute of libgroupby named in ``_CYTHON_FUNCTIONS``.

        Raises
        ------
        NotImplementedError
            If libgroupby has no such attribute, if the dtype is neither
            numeric nor object, or if the dtype is object and "object" is
            not among the function's ``__signatures__``.
        """
        dtype_str = dtype.name
        ftype = cls._CYTHON_FUNCTIONS[kind][how]

        # see if there is a fused-type version of function
        # only valid for numeric
        f = getattr(libgroupby, ftype, None)
        if f is not None:
            if is_numeric:
                return f
            # raise NotImplementedError here rather than TypeError later
            if dtype == object and "object" in f.__signatures__:
                return f

        raise NotImplementedError(
            f"function is not implemented for this dtype: "
            f"[how->{how},dtype->{dtype_str}]"
        )

    def get_cython_func_and_vals(self, values: np.ndarray, is_numeric: bool):
        """
        Find the appropriate cython function, casting if necessary.

        Parameters
        ----------
        values : np.ndarray
        is_numeric : bool

        Returns
        -------
        func : callable
        values : np.ndarray

        Raises
        ------
        NotImplementedError
            If ``how`` is "median" or "cumprod" and the values are not
            numeric, or if the function lookup rejects the dtype.
        """
        how = self.how
        kind = self.kind

        if how in ("median", "cumprod"):
            # these two only have float64 implementations
            if not is_numeric:
                raise NotImplementedError(
                    f"function is not implemented for this dtype: "
                    f"[how->{how},dtype->{values.dtype.name}]"
                )
            values = ensure_float64(values)
            func = getattr(libgroupby, f"group_{how}_float64")
            return func, values

        func = self._get_cython_function(kind, how, values.dtype, is_numeric)

        if values.dtype.kind in ("i", "u"):
            if how in ("add", "var", "prod", "mean", "ohlc"):
                # result may still include NaN, so we have to cast
                values = ensure_float64(values)

        return func, values

    def disallow_invalid_ops(self, dtype: DtypeObj, is_numeric: bool = False):
        """
        Check if we can do this operation with our cython functions.

        Parameters
        ----------
        dtype : DtypeObj
        is_numeric : bool, default False

        Raises
        ------
        NotImplementedError
            This is either not a valid function for this dtype, or
            valid but not implemented in cython.
        """
        how = self.how

        if is_numeric:
            # never an invalid op for those dtypes, so return early as fastpath
            return

        if is_categorical_dtype(dtype) or is_sparse(dtype):
            # categoricals are only 1d, so we
            #  are not setup for dim transforming
            raise NotImplementedError(f"{dtype} dtype not supported")
        elif is_datetime64_any_dtype(dtype):
            # we raise NotImplemented if this is an invalid operation
            #  entirely, e.g. adding datetimes
            if how in ("add", "prod", "cumsum", "cumprod"):
                raise NotImplementedError(
                    f"datetime64 type does not support {how} operations"
                )
        elif is_timedelta64_dtype(dtype):
            if how in ("prod", "cumprod"):
                raise NotImplementedError(
                    f"timedelta64 type does not support {how} operations"
                )

    def get_output_shape(self, ngroups: int, values: np.ndarray) -> Shape:
        """
        Compute the shape of the result buffer for this operation.

        Parameters
        ----------
        ngroups : int
        values : np.ndarray

        Returns
        -------
        Shape
            ``(ngroups, arity)`` for "ohlc", ``values.shape`` for a
            transform, otherwise ``(ngroups,) + values.shape[1:]``.

        Raises
        ------
        NotImplementedError
            If an op other than "ohlc" is registered with arity above 1.
        """
        how = self.how
        kind = self.kind

        arity = self._cython_arity.get(how, 1)

        out_shape: Shape
        if how == "ohlc":
            # take the column count from the arity table instead of
            #  repeating the literal here
            out_shape = (ngroups, arity)
        elif arity > 1:
            raise NotImplementedError(
                "arity of more than 1 is not supported for the 'how' argument"
            )
        elif kind == "transform":
            out_shape = values.shape
        else:
            out_shape = (ngroups,) + values.shape[1:]
        return out_shape

    def get_out_dtype(self, dtype: np.dtype) -> np.dtype:
        """
        Determine the dtype of the result buffer.

        Parameters
        ----------
        dtype : np.dtype
            dtype of the values handed to the cython function.

        Returns
        -------
        np.dtype
            float64 for "rank"; otherwise the dtype rebuilt from the input's
            kind and itemsize when it is numeric, else object.
        """
        how = self.how

        if how == "rank":
            out_dtype = "float64"
        elif is_numeric_dtype(dtype):
            out_dtype = f"{dtype.kind}{dtype.itemsize}"
        else:
            out_dtype = "object"
        return np.dtype(out_dtype)
148
266
149
267
150
268
class BaseGrouper :
@@ -437,8 +555,6 @@ def get_group_levels(self) -> List[Index]:
437
555
# ------------------------------------------------------------
438
556
# Aggregation functions
439
557
440
- _cython_arity = {"ohlc" : 4 } # OHLC
441
-
442
558
@final
443
559
def _is_builtin_func (self , arg ):
444
560
"""
@@ -447,80 +563,6 @@ def _is_builtin_func(self, arg):
447
563
"""
448
564
return SelectionMixin ._builtin_table .get (arg , arg )
449
565
450
- @final
451
- def _get_cython_func_and_vals (
452
- self , kind : str , how : str , values : np .ndarray , is_numeric : bool
453
- ):
454
- """
455
- Find the appropriate cython function, casting if necessary.
456
-
457
- Parameters
458
- ----------
459
- kind : str
460
- how : str
461
- values : np.ndarray
462
- is_numeric : bool
463
-
464
- Returns
465
- -------
466
- func : callable
467
- values : np.ndarray
468
- """
469
- if how in ["median" , "cumprod" ]:
470
- # these two only have float64 implementations
471
- if is_numeric :
472
- values = ensure_float64 (values )
473
- else :
474
- raise NotImplementedError (
475
- f"function is not implemented for this dtype: "
476
- f"[how->{ how } ,dtype->{ values .dtype .name } ]"
477
- )
478
- func = getattr (libgroupby , f"group_{ how } _float64" )
479
- return func , values
480
-
481
- func = _get_cython_function (kind , how , values .dtype , is_numeric )
482
-
483
- if values .dtype .kind in ["i" , "u" ]:
484
- if how in ["add" , "var" , "prod" , "mean" , "ohlc" ]:
485
- # result may still include NaN, so we have to cast
486
- values = ensure_float64 (values )
487
-
488
- return func , values
489
-
490
- @final
491
- def _disallow_invalid_ops (
492
- self , dtype : DtypeObj , how : str , is_numeric : bool = False
493
- ):
494
- """
495
- Check if we can do this operation with our cython functions.
496
-
497
- Raises
498
- ------
499
- NotImplementedError
500
- This is either not a valid function for this dtype, or
501
- valid but not implemented in cython.
502
- """
503
- if is_numeric :
504
- # never an invalid op for those dtypes, so return early as fastpath
505
- return
506
-
507
- if is_categorical_dtype (dtype ) or is_sparse (dtype ):
508
- # categoricals are only 1d, so we
509
- # are not setup for dim transforming
510
- raise NotImplementedError (f"{ dtype } dtype not supported" )
511
- elif is_datetime64_any_dtype (dtype ):
512
- # we raise NotImplemented if this is an invalid operation
513
- # entirely, e.g. adding datetimes
514
- if how in ["add" , "prod" , "cumsum" , "cumprod" ]:
515
- raise NotImplementedError (
516
- f"datetime64 type does not support { how } operations"
517
- )
518
- elif is_timedelta64_dtype (dtype ):
519
- if how in ["prod" , "cumprod" ]:
520
- raise NotImplementedError (
521
- f"timedelta64 type does not support { how } operations"
522
- )
523
-
524
566
@final
525
567
def _ea_wrap_cython_operation (
526
568
self , kind : str , values , how : str , axis : int , min_count : int = - 1 , ** kwargs
@@ -593,9 +635,11 @@ def _cython_operation(
593
635
dtype = values .dtype
594
636
is_numeric = is_numeric_dtype (dtype )
595
637
638
+ cy_op = WrappedCythonOp (kind = kind , how = how )
639
+
596
640
# can we do this operation with our cython functions
597
641
# if not raise NotImplementedError
598
- self . _disallow_invalid_ops (dtype , how , is_numeric )
642
+ cy_op . disallow_invalid_ops (dtype , is_numeric )
599
643
600
644
if is_extension_array_dtype (dtype ):
601
645
return self ._ea_wrap_cython_operation (
@@ -637,43 +681,23 @@ def _cython_operation(
637
681
if not is_complex_dtype (dtype ):
638
682
values = ensure_float64 (values )
639
683
640
- arity = self ._cython_arity .get (how , 1 )
641
684
ngroups = self .ngroups
685
+ comp_ids , _ , _ = self .group_info
642
686
643
687
assert axis == 1
644
688
values = values .T
645
- if how == "ohlc" :
646
- out_shape = (ngroups , 4 )
647
- elif arity > 1 :
648
- raise NotImplementedError (
649
- "arity of more than 1 is not supported for the 'how' argument"
650
- )
651
- elif kind == "transform" :
652
- out_shape = values .shape
653
- else :
654
- out_shape = (ngroups ,) + values .shape [1 :]
655
-
656
- func , values = self ._get_cython_func_and_vals (kind , how , values , is_numeric )
657
-
658
- if how == "rank" :
659
- out_dtype = "float"
660
- else :
661
- if is_numeric :
662
- out_dtype = f"{ values .dtype .kind } { values .dtype .itemsize } "
663
- else :
664
- out_dtype = "object"
665
689
666
- codes , _ , _ = self .group_info
690
+ out_shape = cy_op .get_output_shape (ngroups , values )
691
+ func , values = cy_op .get_cython_func_and_vals (values , is_numeric )
692
+ out_dtype = cy_op .get_out_dtype (values .dtype )
667
693
668
694
result = maybe_fill (np .empty (out_shape , dtype = out_dtype ))
669
695
if kind == "aggregate" :
670
- counts = np .zeros (self . ngroups , dtype = np .int64 )
671
- result = self . _aggregate (result , counts , values , codes , func , min_count )
696
+ counts = np .zeros (ngroups , dtype = np .int64 )
697
+ func (result , counts , values , comp_ids , min_count )
672
698
elif kind == "transform" :
673
699
# TODO: min_count
674
- result = self ._transform (
675
- result , values , codes , func , is_datetimelike , ** kwargs
676
- )
700
+ func (result , values , comp_ids , ngroups , is_datetimelike , ** kwargs )
677
701
678
702
if is_integer_dtype (result .dtype ) and not is_datetimelike :
679
703
mask = result == iNaT
@@ -697,28 +721,6 @@ def _cython_operation(
697
721
698
722
return op_result
699
723
700
- @final
701
- def _aggregate (
702
- self , result , counts , values , comp_ids , agg_func , min_count : int = - 1
703
- ):
704
- if agg_func is libgroupby .group_nth :
705
- # different signature from the others
706
- agg_func (result , counts , values , comp_ids , min_count , rank = 1 )
707
- else :
708
- agg_func (result , counts , values , comp_ids , min_count )
709
-
710
- return result
711
-
712
- @final
713
- def _transform (
714
- self , result , values , comp_ids , transform_func , is_datetimelike : bool , ** kwargs
715
- ):
716
-
717
- _ , _ , ngroups = self .group_info
718
- transform_func (result , values , comp_ids , ngroups , is_datetimelike , ** kwargs )
719
-
720
- return result
721
-
722
724
def agg_series (self , obj : Series , func : F ):
723
725
# Caller is responsible for checking ngroups != 0
724
726
assert self .ngroups != 0
0 commit comments